Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

Python程序中设置HTTP代理

0x00 前言 大家对HTTP代理应该都非常熟悉,它在很多方面都有着极为广泛的应用。HTTP代理分为正向代理和反向代理两种,后者一般用于将防火墙后面的服务提供给用户访问或者进行负载均衡,...

pygame实现俄罗斯方块游戏

pygame实现俄罗斯方块游戏

本文实例为大家分享了pygame实现俄罗斯方块的具体代码,供大家参考,具体内容如下 import random, time, pygame, sys from pygame.loca...

Python面向对象编程中的类和对象学习教程

Python中一切都是对象。类提供了创建新类型对象的机制。这篇教程中,我们不谈类和面向对象的基本知识,而专注在更好地理解Python面向对象编程上。假设我们使用新风格的python类,它...

python cv2截取不规则区域图片实例

python cv2截取不规则区域图片实例

知识掌握 cv2.threshold()函数: 设置固定级别的阈值应用于多通道矩阵,将灰度图像变换二值图像,或去除指定级别的噪声,或过滤掉过小或者过大的像素点。 Python: c...

python中cPickle类使用方法详解

在python中,一般可以使用pickle类来进行python对象的序列化,而cPickle提供了一个更快速简单的接口,如python文档所说的:“cPickle – A faster...