pytorch 更改预训练模型网络结构的方法

yipeiwu_com6年前Python基础

一个继承nn.module的model它包含一个叫做children()的函数,这个函数可以用来提取出model每一层的网络结构,在此基础上进行修改即可,修改方法如下(去除后两层):

resnet_layer = nn.Sequential(*list(model.children())[:-2])

那么,接下来就可以构建我们的网络了:

class Net(nn.Module):
  def __init__(self , model):
    super(Net, self).__init__()
    #取掉model的后两层
    self.resnet_layer = nn.Sequential(*list(model.children())[:-2])
    
    self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
    self.pool_layer = nn.MaxPool2d(32) 
    self.Linear_layer = nn.Linear(2048, 8)
    
  def forward(self, x):
    x = self.resnet_layer(x)
 
    x = self.transion_layer(x)
 
    x = self.pool_layer(x)
 
    x = x.view(x.size(0), -1) 
 
    x = self.Linear_layer(x)
    
    return x

最后,构建一个对象,并加载resnet预训练的参数就可以啦~

resnet = models.resnet50(pretrained=True)
model = Net(resnet)

以上这篇pytorch 更改预训练模型网络结构的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python输出电脑上所有的串口名的方法

python输出电脑上所有的串口名的方法

输出电脑上所有的串口名: import serial import serial.tools.list_ports from easygui import * port_list...

Python 字符串定义

例如:'string'、"string"、"""string"""或者是'''string'''。在使用上,单引号和双引号没有什么区别。三引号的主要功能是在字符串中可以包含换行。也就是说...

Python登录并获取CSDN博客所有文章列表代码实例

Python登录并获取CSDN博客所有文章列表代码实例

分析登录过程 这几天研究百度登录和贴吧签到,这百度果然是互联网巨头,一个登录过程都弄得复杂无比,简直有毒。我研究了好几天仍然没搞明白。所以还是先挑一个软柿子捏捏,就选择CSDN了。 过程...

学生信息管理系统python版

本文实例为大家分享了python学生信息管理系统的具体代码,供大家参考,具体内容如下 #!/usr/bin/env python # @Time : 2018/3/30 17:37...

Django用户认证系统 组与权限解析

Django的权限系统很简单,它可以赋予users或groups中的users以权限。 Django admin后台就使用了该权限系统,不过也可以用到你自己的代码中。 User对象具有两...