Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

opencv python统计及绘制直方图的方法

opencv python统计及绘制直方图的方法

灰度直方图概括了图像的灰度级信息,简单的来说就是每个灰度级图像中的像素个数以及占有率,创建直方图无外乎两个步骤,统计直方图数据,再用绘图库绘制直方图。 统计直方图数据 首先要稍微理解一些...

pytorch:实现简单的GAN示例(MNIST数据集)

我就废话不多说了,直接上代码吧! # -*- coding: utf-8 -*- """ Created on Sat Oct 13 10:22:45 2018 @author: w...

Python计算时间间隔(精确到微妙)的代码实例

Python计算时间间隔(精确到微妙)的代码实例

使用python中的datetime import datetime oldtime=datetime.datetime.now() print oldtime; x=1 while...

Python 3.x基于Xml数据的Http请求方法

Python 3.x基于Xml数据的Http请求方法

1. 前言 由于公司的一个项目是基于B/S架构与WEB服务通信,使用XML数据作为通信数据,在添加新功能时,WEB端与客户端分别由不同的部门负责,所以在WEB端功能实现过程中,需要自己发...

python获取微信小程序手机号并绑定遇到的坑

python获取微信小程序手机号并绑定遇到的坑

最近在做小程序开发,在其中也遇到了很多的坑,获取小程序的手机号并绑定就遇到了一个很傻的坑。 流程介绍 官方流程图 小程序使用方法 需要将 <button> 组件 open...