在Pytorch中计算自己模型的FLOPs方式

yipeiwu_com6年前Python基础

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

以上这篇在Pytorch中计算自己模型的FLOPs方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python执行子进程实现进程间通信的方法

本文实例讲述了python执行子进程实现进程间通信的方法。分享给大家供大家参考。具体实现方法如下: a.py: import subprocess, time subproc = s...

python一键去抖音视频水印工具

python一键去抖音视频水印工具

无水印视频下载 方法一: 无水印视频下载很简单,有一个通用的方法,就是使用去水印平台即可。 我使用的去水印平台是:http://douyin.iiilab.com/ 在输入框中输入视频链...

浅析python继承与多重继承

记住以下几点: 直接子类化内置类型(如dict,list或str)容易出错,因为内置类型的方法通常会忽略用户覆盖的方法,不要子类化内置类型,用户自定义的类应该继承collections模...

轻松掌握python设计模式之访问者模式

轻松掌握python设计模式之访问者模式

本文实例为大家分享了python访问者模式代码,供大家参考,具体内容如下 """访问者模式""" class Node(object): pass class A(Node):...

Python用imghdr模块识别图片格式实例解析

imghdr模块 功能描述:imghdr模块用于识别图片的格式。它通过检测文件的前几个字节,从而判断图片的格式。 唯一一个API imghdr.what(file, h=None) 第一...