Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

使用Python操作Elasticsearch数据索引的教程

使用Python操作Elasticsearch数据索引的教程

Elasticsearch是一个分布式、Restful的搜索及分析服务器,Apache Solr一样,它也是基于Lucence的索引服务器,但我认为Elasticsearch对比Solr...

python使用matplotlib模块绘制多条折线图、散点图

python使用matplotlib模块绘制多条折线图、散点图

今天想直观的展示一下数据就用到了matplotlib模块,之前都是一张图只有一条曲线,现在想同一个图片上绘制多条曲线来对比,实现很简单,具体如下: #!usr/bin/env pyt...

python通过正则查找微博@(at)用户的方法

本文实例讲述了python通过正则查找微博@(at)用户的方法。分享给大家供大家参考。具体如下: 这段代码用到了python正则的findall方法,查找所有被@的用户,使用数组形式返回...

Python插件virtualenv搭建虚拟环境

Python插件virtualenv搭建虚拟环境

这里想象一下需求,写一个项目使用的一系列1.0版本的插件,现在要新写一个项目,需要用这些插件的2.0版本,该怎么办?都更新成2.0版本?这样之前的项目都没法维护了 这时我们需要一个虚拟环...

python实现字典嵌套列表取值

如下所示: dict={'log_id': 5891599090191187877, 'result_num': 1, 'result': [{'probability': 0.98...