在Pytorch中计算自己模型的FLOPs方式

yipeiwu_com6年前Python基础

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

以上这篇在Pytorch中计算自己模型的FLOPs方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

浅谈Tensorflow 动态双向RNN的输出问题

浅谈Tensorflow 动态双向RNN的输出问题

tf.nn.bidirectional_dynamic_rnn()函数:def bidirectional_dynamic_rnn(   cell_fw,&...

python将秒数转化为时间格式的实例

1、转化成时间格式 seconds =35400 m, s = divmod(seconds, 60) h, m = divmod(m, 60) print("%d:%02d:%02...

python 生成器协程运算实例

一、yield运行方式 我们定义一个如下的生成器: def put_on(name): print("Hi {}, 货物来了,准备搬到仓库!".format(name)) wh...

python实现字典嵌套列表取值

如下所示: dict={'log_id': 5891599090191187877, 'result_num': 1, 'result': [{'probability': 0.98...

python-opencv获取二值图像轮廓及中心点坐标的代码

python-opencv获取二值图像轮廓及中心点坐标代码: groundtruth = cv2.imread(groundtruth_path)[:, :, 0] h1, w1 =...