pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现聊天机器人的示例代码

Python实现聊天机器人的示例代码

一、AIML是什么 AIML全名为Artificial Intelligence Markup Language(人工智能标记语言),是一种创建自然语言软件代理的XML语言,是由Ric...

python中栈的原理及实现方法示例

本文实例讲述了python中栈的原理及实现方法。分享给大家供大家参考,具体如下: 栈(stack),有些地方称为堆栈,是一种容器,可存入数据元素、访问元素、删除元素,它的特点在于只能允许...

python实现监控某个服务 服务崩溃即发送邮件报告

前言:最近我们的升级服务器有点不太稳定,经常崩溃掉。然后客户连接不上,跟我们反馈才知道。所以写这个脚本的目的就是为了比客户提前知道升级服务的运行状况,一旦崩溃掉,就能第一时间登录上去,开...

详解在Python的Django框架中创建模板库的方法

不管是写自定义标签还是过滤器,第一件要做的事是创建模板库(Django能够导入的基本结构)。 创建一个模板库分两步走:     第一,决定模板库应该放在哪个...

Python计算库numpy进行方差/标准方差/样本标准方差/协方差的计算

使用numpy可以做很多事情,在这篇文章中简单介绍一下如何使用numpy进行方差/标准方差/样本标准方差/协方差的计算。 variance: 方差 方差(Variance)是概率论中最基...