在Pytorch中计算自己模型的FLOPs方式

yipeiwu_com6年前Python基础

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

以上这篇在Pytorch中计算自己模型的FLOPs方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 的类、继承和多态详解

类的定义 假如要定义一个类 Point,表示二维的坐标点: # point.py class Point: def __init__(self, x=0, y=0): s...

python使用scrapy解析js示例

复制代码 代码如下:from selenium import selenium class MySpider(CrawlSpider):    name =...

Python中对列表排序实例

很多时候,我们需要对List进行排序,Python提供了两个方法,对给定的List L进行排序: 方法1.用List的成员函数sort进行排序 方法2.用built-in函数sorted...

Python3常用内置方法代码实例

这篇文章主要介绍了Python3常用内置方法代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 max()/min() 传...

python遍历序列enumerate函数浅析

enumerate函数用于遍历序列中的元素以及它们的下标。 enumerate函数说明: 函数原型:enumerate(sequence, [start=0]) 功能:将可循环序列seq...