pytorch 修改预训练model实例

yipeiwu_com6年前Python基础

我就废话不多说了,直接上代码吧!

 class Net(nn.Module):
  def __init__(self , model):
   super(Net, self).__init__()
   #取掉model的后两层
   self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
   self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
   self.pool_layer = nn.MaxPool2d(32) 
   self.Linear_layer = nn.Linear(2048, 8)
   
  def forward(self, x):
   x = self.resnet_layer(x)
   x = self.transion_layer(x)
   x = self.pool_layer(x)
   x = x.view(x.size(0), -1) 
   x = self.Linear_layer(x) 
   return x
resnet = models.resnet50(pretrained=True)
model = Net(resnet)

以上这篇pytorch 修改预训练model实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python对list列表结构中的值进行去重的方法总结

今天遇到一个问题,在同事随意的提示下,用了 itertools.groupby 这个函数。不过这个东西最终还是没用上。 问题就是对一个list中的新闻id进行去重,去重之后要保证顺序不变...

python 内置模块详解

一.random模块  随机       random()    随机小数 ...

Python MongoDB 插入数据时已存在则不执行,不存在则插入的解决方法

本文实例讲述了Python MongoDB 插入数据时已存在则不执行,不存在则插入的解决方法。分享给大家供大家参考,具体如下: 前言: 想把QQ日志爬虫(Python)爬下来的日志保存到...

python pycurl验证basic和digest认证的方法

简介 pycurl类似于Python的urllib,但是pycurl是对libcurl的封装,速度更快。 本文使用的是pycurl 7.43.0.1版本。 Apache下配置Basic认...

Python 中pandas索引切片读取数据缺失数据处理问题

Python 中pandas索引切片读取数据缺失数据处理问题

引入   numpy已经能够帮助我们处理数据,能够结合matplotlib解决我们数据分析的问题,那么pandas学习的目的在什么地方呢? numpy能够帮我们处理处理数值型数据,但是这...