Pytorch修改ResNet模型全连接层进行直接训练实例

yipeiwu_com6年前Python基础

之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把

最后一层的输出改一下,不需要加载前面层的权重,方法如下:

model = torchvision.models.resnet18(pretrained=False)
num_fc_ftr = model.fc.in_features
model.fc = torch.nn.Linear(num_fc_ftr, 224)
model = nn.DataParallel(model, device_ids=config.gpus).to(device)

首先模型结构是必须要传入的,然后把最后一层的输出改为自己所需的数目

以上知识点很简单,大家可以测试下,感谢大家的阅读和对【听图阁-专注于Python设计】的支持。

相关文章

Python使用Pycrypto库进行RSA加密的方法详解

密码与通信 密码技术是一门历史悠久的技术。信息传播离不开加密与解密。密码技术的用途主要源于两个方面,加密/解密和签名/验签 在信息传播中,通常有发送者,接受者和窃听者三个角色。假设发送者...

Python中按值来获取指定的键

Python字典中的键是唯一的,但不同的键可以对应同样的值,比如说uid,可以是1001。id同样可以是1001。这样的话通过值来获取指定的键,就不止一个!而且也并不太好处理。这里同样提...

基于python list对象中嵌套元组使用sort时的排序方法

在list中嵌套元组,在进行sort排序的时候,产生的是原数组的副本,排序过程中,先根据第一个字段进行从小到大排序,如果第一个字段相同的话,再根据第二个字段进行排序,依次类推,当涉及到字...

Gauss-Seidel迭代算法的Python实现详解

import numpy as np import time 1.1 Gauss-Seidel迭代算法 def GaussSeidel_tensor_V2(A,b,Delta,...

在linux下实现 python 监控usb设备信号

1. linux下消息记录 关于系统的各种消息一般都会记录在/var/log/messages文件中,有些主机在中默认情况下有可能没有启用,具体配置方法可参考下面这篇博客: 系统日志配置...