pytorch 固定部分参数训练的方法

yipeiwu_com5年前Python基础

需要自己过滤

optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

另外,如果是Variable,则可以初始化时指定

j = Variable(torch.randn(5,5), requires_grad=True)

但是如果是

m = nn.Linear(10,10)

是没有requires_grad传入的

m.requires_grad也没有

需要

for i in m.parameters():
  i.requires_grad=False

另外一个小技巧就是在nn.Module里,可以在中间插入这个

for p in self.parameters():
  p.requires_grad=False

这样前面的参数就是False,而后面的不变

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 6, 5)
    self.conv2 = nn.Conv2d(6, 16, 5)

    for p in self.parameters():
      p.requires_grad=False

    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

以上这篇pytorch 固定部分参数训练的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python输出带颜色的字符串实例

Python输出带颜色的字符串实例

输出带颜色的字符串,用来显示要突出的部分。经测验,在pycharm中可行,在windows命令行中不可行。原因未知。 方法: 格式:"\033[显示方式;前景色;背景色m 需要变颜色...

详解Python3 pandas.merge用法

详解Python3 pandas.merge用法

摘要 数据分析与建模的时候大部分时间在数据准备上,包括对数据的加载、清理、转换以及重塑。pandas提供了一组高级的、灵活的、高效的核心函数,能够轻松的将数据规整化。这节主要对panda...

深入解析Python中函数的参数与作用域

传递参数 函数传递参数时的一些简要的关键点: 参数的传递是通过自动将对象赋值给本地变量名来实现的。所有的参数实际上都是通过指针进行传递的,作为参数被传递的对象从来不自动拷贝。...

Python获取当前公网ip并自动断开宽带连接实例代码

今天写了一个获取当前公网ip并且自动断开宽带连接的文件,和大家分享下。 这个文件的具体用途大家懂的,可以尽管拿去用,不过目前只适用于Windows平台,我的Python版本是2.7的,...

对web.py设置favicon.ico的方法详解

本文介绍在web.py中设置favicon.ico的方法: 如果没设置favicon,后台日志是这样的: 127.0.0.1:4133 - - [03/Sep/2015 18:49:...