Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

python一行sql太长折成多行并且有多个参数的方法

sql语句 有一个非常长的sql,用编辑器打开编写的时候太长了导致编写非常吃力,而且容易错乱,我想做的是把A,B,C三个变量赋值到sql中的字段中去 A=1 B=2 C=3 sql...

Python 条件判断的缩写方法

return (1==1) ? "is easy" : "my god" //C...

调用其他python脚本文件里面的类和方法过程解析

这篇文章主要介绍了调用其他python脚本文件里面的类和方法过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 问题描述: 自己编...

在Python的Flask框架中构建Web表单的教程

在Python的Flask框架中构建Web表单的教程

尽管Flask的request对象提供的支持足以处理web表单,但依然有许多任务会变得单调且重复。表单的HTML代码生成和验证提交的表单数据就是两个很好的例子。 Flask-WTF扩展使...

Python 一键获取百度网盘提取码的方法

Python 一键获取百度网盘提取码的方法

该 GIF 图来自于官网,文末有给出链接。 描述 依托于百度网盘巨大的的云存储空间,绝大数人会习惯性的将一些资料什么的存储到上面,但是有的私密链接需要提取码,但是让每个想下载私密资源的...