pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com5年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

详解Django的model查询操作与查询性能优化

1 如何 在做ORM查询时 查看SQl的执行情况 (1) 最底层的 django.db.connection 在 django shell 中使用  python manag...

python批量提取word内信息

单位收集了很多word格式的调查表,领导需要收集表单里的信息,我就把所有调查表放一个文件里,写了个python小程序把所需的信息打印出来 #coding:utf-8 impo...

Python实现两个list对应元素相减操作示例

本文实例讲述了Python实现两个list对应元素相减操作。分享给大家供大家参考,具体如下: 两个list的对应元素操作,这里以相减为例: # coding=gbk v1 = [21...

Python去除字符串两端空格的方法

目的   获得一个首尾不含多余空格的字符串 方法 可以使用字符串的以下方法处理: string.lstrip(s[, chars]) Return a copy of the stri...

python设置检查点简单实现代码

说检查点,其实就是对过去历史的记录,可以认为是log.不过这里进行了简化.举例来说,我现在又一段文本.文本里放有一堆堆的链接地址.我现在的任务是下载那些地址中的内容.另外因为网络的问题或...