解决Pytorch训练过程中loss不下降的问题

yipeiwu_com5年前Python基础

在使用Pytorch进行神经网络训练时,有时会遇到训练学习率不下降的问题。出现这种问题的可能原因有很多,包括学习率过小,数据没有进行Normalization等。不过除了这些常规的原因,还有一种难以发现的原因:在计算loss时数据维数不匹配。

下面是我的代码:

loss_function = torch.nn.MSE_loss()
optimizer.zero_grad()
output = model(x_train)
loss = loss_function(output, y_train)
loss.backward()
optimizer.step()

要特别注意计算loss时网络输出值output和真实值y_train的维数必须完全匹配,否则训练误差不下降,无法训练。这种错误在训练一维数据时很容易忽略,要十分注意。

以上这篇解决Pytorch训练过程中loss不下降的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python学习笔记之if语句的使用示例

前言 条件语句在实际开发中我们已经使用过几次了,在这里我们需要再次隆重的来介绍一下它,下面话不多说了,来一起看看详细的介绍吧。 if语句 顾名思义,该语句为判断语句,先来一个简单的示例...

python del()函数用法

示例程序如下: >>> a = [-1, 3, 'aa', 85] # 定义一个list>>> a[-1, 3, 'aa', 85]>>...

python实现从文件中读取数据并绘制成 x y 轴图形的方法

python实现从文件中读取数据并绘制成 x y 轴图形的方法

如下所示: import matplotlib.pyplot as plt import numpy as np def readfile(filename): dataLis...

Ubuntu18.04中Python2.7与Python3.6环境切换

Ubuntu18.04中Python2.7与Python3.6环境切换

本文为大家分享了Python2.7与Python3.6环境切换的具体方法,供大家参考,具体内容如下 系统支持为:Ubuntu18.04 系统默认安装:Python2.7 自己安装:Pyt...

Python获取当前公网ip并自动断开宽带连接实例代码

今天写了一个获取当前公网ip并且自动断开宽带连接的文件,和大家分享下。 这个文件的具体用途大家懂的,可以尽管拿去用,不过目前只适用于Windows平台,我的Python版本是2.7的,...