在Pytorch中计算自己模型的FLOPs方式

yipeiwu_com6年前Python基础

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

以上这篇在Pytorch中计算自己模型的FLOPs方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

通过实例学习Python Excel操作

通过实例学习Python Excel操作

这篇文章主要介绍了通过实例学习Python Excel操作,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 1.python 读取Exc...

Python获取时间范围内日期列表和周列表的函数

Python获取时间范围内日期列表和周列表的函数

Python获取时间范围内日期列表和周列表的函数 1、获取日期列表 # -*- coding=utf-8 -*- import datetime def dateRange(beg...

Python 使用 PyMysql、DBUtils 创建连接池提升性能

Python 使用 PyMysql、DBUtils 创建连接池提升性能

Python 编程中可以使用 PyMysql 进行数据库的连接及诸如查询/插入/更新等操作,但是每次连接 MySQL 数据库请求时,都是独立的去请求访问,相当浪费资源,而且访问数量达到一...

解决Python传递中文参数的问题

今天有个需要需要传递中文参数给URL 但是在GBK环境下的脚本传递GBK的参数老是给我报UNICODE的解码错误。烦的很。 所以我们果断选择用urlencode来处理中文, 由于国内外网...

详解numpy.meshgrid()方法使用

详解numpy.meshgrid()方法使用

一句话解释numpy.meshgrid()——生成网格点坐标矩阵。 关键词:网格点,坐标矩阵 网格点是什么?坐标矩阵又是什么鬼? 看个图就明白了: 图中,每个交叉点都是网格点,描...