Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

Python读取csv文件分隔符设置方法

Windows下的分隔符默认的是逗号,而MAC的分隔符是分号。拿到一份用分号分割的CSV文件,在Win下是无法正确读取的,因为CSV模块默认调用的是Excel的规则。 所以我们在读取文件...

Django框架基础模板标签与filter使用方法详解

Django框架基础模板标签与filter使用方法详解

本文实例讲述了Django框架基础模板标签与filter使用方法。分享给大家供大家参考,具体如下: 一、基本的模板语言 1、变量 {{ }} 1.1、进入Django shell 环境...

Python中的取模运算方法

Python中的取模运算方法

所谓取模运算,就是计算两个数相除之后的余数,符号是%。如a % b就是计算a除以b的余数。用数学语言来描述,就是如果存在整数n和m,其中0 <= m < b,使得a = n...

值得收藏的10道python 面试题

值得收藏的10道python 面试题

Q1:PEP8是什么?Python之禅(import this)是什么? 这题是考察你对编码规范的认识,无论是自己写代码还是在团队中写代码,了解并遵循代码规范是很基础的要求。企业中在提交...

利用Python命令行传递实例化对象的方法

一、前言 在开发过程中,遇到了这样一个情况:我们需要在脚本中通过 suprocess.call 方法来启动另外一个脚本(脚本 B),当然啦,还得传递一些参数。在这些参数中,有一个需要传...