pytorch 实现打印模型的参数值

yipeiwu_com6年前Python基础

对于简单的网络

例如全连接层Linear

可以使用以下方法打印linear层:

fc = nn.Linear(3, 5)
params = list(fc.named_parameters())
print(params.__len__())
print(params[0])
print(params[1])

输出如下:

由于Linear默认是偏置bias的,所有参数列表的长度是2。第一个存的是全连接矩阵,第二个存的是偏置。

对于稍微复杂的网络

例如MLP

mlp = nn.Sequential(
      nn.Dropout(p=0.3),
      nn.Linear(1024, 256),
      nn.Linear(256, 64),
      nn.Linear(64, 16),
      nn.Linear(16, 1)
    )
params = list(mlp.named_parameters())
print(params.__len__())

print(params[0])
print(params[1])

print(params[2])
print(params[3])

输出:

可以发现,堆叠起来的网络,参数是依次放置的。先是全连接的权重,然后偏置。然后是下一层网络的权重+偏置。依次进行下去。

这里有4层fc,4*2=8.所以一共有8个参数矩阵。

以上这篇pytorch 实现打印模型的参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python 装饰器深入理解

讲 Python 装饰器前,我想先举个例子,虽有点污,但跟装饰器这个话题很贴切。 每个人都有的内裤主要功能是用来遮羞,但是到了冬天它没法为我们防风御寒,咋办?我们想到的一个办法就是把内裤...

python实现的自动发送消息功能详解

python实现的自动发送消息功能详解

本文实例讲述了python实现的自动发送消息功能。分享给大家供大家参考,具体如下: 一个简单的脚本 #-*- coding:utf-8 -*- from __future__ imp...

python实现代码统计程序

本文实例为大家分享了python实现代码统计程序的具体代码,供大家参考,具体内容如下 # encoding="utf-8" """ 统计代码行数 """ import sys i...

Python Pexpect库的简单使用方法

简介 最近需要远程操作一个服务器并执行该服务器上的一个python脚本,查到可以使用Pexpect这个库。记录一下。 什么是Pexpect?Pexpect能够产生子应用程序,并控制他们...

Django中使用locals()函数的技巧

对 current_datetime 的一次赋值操作: def current_datetime(request): now = datetime.datetime.now()...