在Pytorch中计算自己模型的FLOPs方式

yipeiwu_com6年前Python基础

https://github.com/Lyken17/pytorch-OpCounter

安装方法很简单:

pip install thop

基本用法:

from torchvision.models import resnet50from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224))

对自己的module进行特别的计算:

class YourModule(nn.Module):
# your definition
def count_your_model(model, x, y):
# your rule
hereflops, params = profile(model, input_size=(1, 3, 224,224),
custom_ops={YourModule: count_your_model})

以上这篇在Pytorch中计算自己模型的FLOPs方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python scipy求解非线性方程的方法(fsolve/root)

python scipy求解非线性方程的方法(fsolve/root)

使用scipy.optimize模块的root和fsolve函数进行数值求解线性及非线性方程,下面直接贴上代码,代码很简单 from scipy.integrate import o...

Python实现的计算马氏距离算法示例

Python实现的计算马氏距离算法示例

本文实例讲述了Python实现的计算马氏距离算法。分享给大家供大家参考,具体如下: 我给写成函数调用了 python实现马氏距离源代码: # encoding: utf-8 fro...

理想高通滤波实现Python opencv示例

理想高通滤波实现Python opencv示例

理想高通滤波实现 python opencv import numpy as np import cv2 from matplotlib import pyplot as plt...

Python imutils 填充图片周边为黑色的实现

Python imutils 填充图片周边为黑色的实现

代码 import imutils import cv2 image = cv2.imread('') # translate the image x=25 pixels to t...

python输出电脑上所有的串口名的方法

python输出电脑上所有的串口名的方法

输出电脑上所有的串口名: import serial import serial.tools.list_ports from easygui import * port_list...