PyTorch实现更新部分网络,其他不更新

yipeiwu_com5年前Python基础

torch.Tensor.detach()的使用

detach()的官方说明如下:

Returns a new Tensor, detached from the current graph.
The result will never require gradient.

假设有模型A和模型B,我们需要将A的输出作为B的输入,但训练时我们只训练模型B. 那么可以这样做:

input_B = output_A.detach()

它可以使两个计算图的梯度传递断开,从而实现我们所需的功能。

以上这篇PyTorch实现更新部分网络,其他不更新就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python使用paramiko操作linux的方法讲解

paramiko介绍 paramiko是一个基于python编写的、使用ssh协议的模块,跟xshell和xftp功能类似,支持加密与认证,可以上传下载和访问服务器的文件。 可以利用pa...

python根据时间获取周数代码实例

时间 时间和周数 import time import datetime # 获取今天是第几周 print(time.strftime('%W')) # 获取当前是周几(0-6,0...

Python在信息学竞赛中的运用及Python的基本用法(详解)

Python在信息学竞赛中的运用及Python的基本用法(详解)

前言 众所周知,Python是一种非常实用的语言。但是由于其运算时的低效和解释型编译,在信息学竞赛中并不用于完成算法程序。但正如LRJ在《算法竞赛入门经典-训练指南》中所说的一样,如果会...

python删除字符串中指定字符的方法

最近开始学机器学习,学习分析垃圾邮件,其中有一部分是要求去除一段字符中的标点符号,查了一下,网上的大多很复杂例如这样 import re temp = "司法局让我和户 1 5....

Python数据持久化存储实现方法分析

本文实例讲述了Python数据持久化存储实现方法。分享给大家供大家参考,具体如下: 1、pymongo的使用 前三步为创建对象 第一步创建连接对象 conn = pymong...