pytorch 实现模型不同层设置不同的学习率方式

yipeiwu_com6年前Python基础

在目标检测的模型训练中, 我们通常都会有一个特征提取网络backbone, 例如YOLO使用的darknet SSD使用的VGG-16。

为了达到比较好的训练效果, 往往会加载预训练的backbone模型参数, 然后在此基础上训练检测网络, 并对backbone进行微调, 这时候就需要为backbone设置一个较小的lr。

class net(torch.nn.Module):
  def __init__(self):
    super(net, self).__init__()
    # backbone
    self.backbone = ...
    # detect
    self....

在设置optimizer时, 只需要参数分为两个部分, 并分别给定不同的学习率lr。

base_params = list(map(id, net.backbone.parameters()))
logits_params = filter(lambda p: id(p) not in base_params, net.parameters())
params = [
  {"params": logits_params, "lr": config.lr},
  {"params": net.backbone.parameters(), "lr": config.backbone_lr},
]
optimizer = torch.optim.SGD(params, momentum=config.momentum, weight_decay=config.weight_decay)
 

以上这篇pytorch 实现模型不同层设置不同的学习率方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python获取文件ssdeep值的方法

本文实例讲述了Python获取文件ssdeep值的方法,分享给大家供大家参考。具体方法如下: 首先,得到ssdeep值,需要先import ssdeep 在ubuntu上安装pyssde...

Python实现冒泡,插入,选择排序简单实例

本文所述的Python实现冒泡,插入,选择排序简单实例比较适合Python初学者从基础开始学习数据结构和算法,示例简单易懂,具体代码如下: # -*- coding: cp936 -...

Python实现将xml导入至excel

Python实现将xml导入至excel

最近在使用Testlink时,发现导入的用例是xml格式,且没有合适的工具转成excel格式,xml使用excel打开显示的东西也太多,网上也有相关工具转成csv格式的,结果也不合人意。...

Python在OpenCV里实现极坐标变换功能

Python在OpenCV里实现极坐标变换功能

在中学里学习过直角坐标系,也叫做笛卡尔坐标系,它是正交坐标系,不过也学习过极坐标系,这种坐标系比较适合大炮发射的场合。极坐标系的定义如下: 在 平面内取一个定点O, 叫极点,引一条射线O...

Python测试Kafka集群(pykafka)实例

生产者代码: # -* coding:utf8 *- from pykafka import KafkaClient host = 'IP:9092, IP:9092, IP...