两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试

2019 年 7 月 7 日 CVer

点击上方“CVer”,选择加"星标"或“置顶”

重磅干货,第一时间送达

作者:思源

本文转载自:机器之心

你的模型到底有多少参数,每秒的浮点运算到底有多少,这些你都知道吗?近日,GitHub 开源了一个小工具,它可以统计 PyTorch 模型的参数量与每秒浮点运算数(FLOPs)。有了这两种信息,模型大小控制也就更合理了。

其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。


此外,机器学习还有很多结构没有参数但存在计算,例如最大池化Dropout 等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。


  • PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter


OpCouter


PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。


我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。


对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:

from torchvision.models import resnet50from thop import profile
model = resnet50()input = torch.randn(1, 3, 224, 224)flops, params = profile(model, inputs=(input, ))


我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。



最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。

flops: 2914598912.0parameters: 7978856.0


OpCouter 是怎么算的


我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:

def count_conv2d(m, x, y):    x = x[0]
cin = m.in_channels cout = m.out_channels kh, kw = m.kernel_size batch_size = x.size()[0]
out_h = y.size(2) out_w = y.size(3)
# ops per output element # kernel_mul = kh * kw * cin # kernel_add = kh * kw * cin - 1 kernel_ops = multiply_adds * kh * kw bias_ops = 1 if m.bias is not None else 0 ops_per_element = kernel_ops + bias_ops
# total ops # num_out_elements = y.numel() output_elements = batch_size * out_w * out_h * cout total_ops = output_elements * ops_per_element * cin // m.groups
    m.total_ops = torch.Tensor([int(total_ops)])


总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。


定制你的运算统计


有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。

class YourModule(nn.Module):    # your definitiondef count_your_model(model, x, y):    # your rule here
input = torch.randn(1, 3, 224, 224)flops, params = profile(model, inputs=(input, ),                        custom_ops={YourModule: count_your_model})


最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量:



CVer学术交流群


扫码添加CVer助手,可申请加入CVer-目标检测交流群、图像分割、目标跟踪、人脸检测&识别、OCR、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶和剪枝&压缩等群。一定要备注:研究方向+地点+学校/公司+昵称(如目标检测+上海+上交+卡卡)

▲长按加群


▲长按关注我们

麻烦给我一个在看

登录查看更多
1

相关内容

统计模型[stochasticmodel;statisticmodel;probabilitymodel]指以概率论为基础,采用数学统计方法建立的模型。有些过程无法用理论分析方法导出其模型,但可通过试验测定数据,经过数理统计法求得各变量之间的函数关系,称为统计模型。常用的数理统计分析方法有最大事后概率估算法、最大似然率辨识法等。常用的统计模型有一般线性模型、广义线性模型和混合模型。统计模型的意义在对大量随机事件的规律性做推断时仍然具有统计性,因而称为统计推断。
【伯克利】再思考 Transformer中的Batch Normalization
专知会员服务
41+阅读 · 2020年3月21日
2020图机器学习GNN的四大研究趋势,21篇论文下载
专知会员服务
136+阅读 · 2020年2月10日
模型压缩究竟在做什么?我们真的需要模型压缩么?
专知会员服务
28+阅读 · 2020年1月16日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
深度神经网络模型压缩与加速综述
专知会员服务
129+阅读 · 2019年10月12日
谷歌EfficientNet缩放模型,PyTorch实现登热榜
机器学习算法与Python学习
11+阅读 · 2019年6月4日
100行Python代码,轻松搞定神经网络
大数据文摘
4+阅读 · 2019年5月2日
超强干货!TensorFlow易用代码大集合...
机器学习算法与Python学习
6+阅读 · 2019年2月20日
超全总结:神经网络加速之量化模型 | 附带代码
一次 PyTorch 的踩坑经历,以及如何避免梯度成为NaN
Caffe 深度学习框架上手教程
黑龙江大学自然语言处理实验室
14+阅读 · 2016年6月12日
Sparse Sequence-to-Sequence Models
Arxiv
5+阅读 · 2019年5月14日
Arxiv
11+阅读 · 2018年5月13日
Arxiv
3+阅读 · 2017年10月1日
Arxiv
5+阅读 · 2017年7月23日
VIP会员
相关资讯
谷歌EfficientNet缩放模型,PyTorch实现登热榜
机器学习算法与Python学习
11+阅读 · 2019年6月4日
100行Python代码,轻松搞定神经网络
大数据文摘
4+阅读 · 2019年5月2日
超强干货!TensorFlow易用代码大集合...
机器学习算法与Python学习
6+阅读 · 2019年2月20日
超全总结:神经网络加速之量化模型 | 附带代码
一次 PyTorch 的踩坑经历,以及如何避免梯度成为NaN
Caffe 深度学习框架上手教程
黑龙江大学自然语言处理实验室
14+阅读 · 2016年6月12日
Top
微信扫码咨询专知VIP会员