pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

使用django的ORM框架按月统计近一年内的数据方法

如下所示: # 计算时间 time = datetime.datetime.now() - relativedelta(years=1) # 获取近一年数据 one_year_dat...

python集合的创建、添加及删除操作示例

本文实例讲述了python集合的创建、添加及删除操作。分享给大家供大家参考,具体如下: 集合时无序可变的序列,集合中的元素放在{}内,集合中的元素具有唯一性。 集合中只能包含数字、字符串...

Python2.x中str与unicode相关问题的解决方法

Python2.x中str与unicode相关问题的解决方法

python2.x中处理中文,是一件头疼的事情。网上写这方面的文章,测次不齐,而且都会有点错误,所以在这里打算自己总结一篇文章。 我也会在以后学习中,不断的修改此篇博客。 这里假设读者已...

Python字符串拼接的几种方法整理

Python字符串拼接的几种方法整理

Python字符串拼接的几种方法整理 第一种 通过加号(+)的形式 print('第一种方式通过加号形式连接 :' + 'love'+'Python' + '\n') 第二种 通...

解决Pycharm 包已经下载,但是运行代码提示找不到模块的问题

解决Pycharm 包已经下载,但是运行代码提示找不到模块的问题

问题产生: pycharm→settings→Project interpreter→下载matplotlib包 运行代码,出现以下提示:找不到‘matplotlib'模块ModuleN...