pytorch 实现打印模型的参数值

yipeiwu_com6年前Python基础

对于简单的网络

例如全连接层Linear

可以使用以下方法打印linear层:

fc = nn.Linear(3, 5)
params = list(fc.named_parameters())
print(params.__len__())
print(params[0])
print(params[1])

输出如下:

由于Linear默认是偏置bias的,所有参数列表的长度是2。第一个存的是全连接矩阵,第二个存的是偏置。

对于稍微复杂的网络

例如MLP

mlp = nn.Sequential(
      nn.Dropout(p=0.3),
      nn.Linear(1024, 256),
      nn.Linear(256, 64),
      nn.Linear(64, 16),
      nn.Linear(16, 1)
    )
params = list(mlp.named_parameters())
print(params.__len__())

print(params[0])
print(params[1])

print(params[2])
print(params[3])

输出:

可以发现,堆叠起来的网络,参数是依次放置的。先是全连接的权重,然后偏置。然后是下一层网络的权重+偏置。依次进行下去。

这里有4层fc,4*2=8.所以一共有8个参数矩阵。

以上这篇pytorch 实现打印模型的参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

一文带你了解Python中的字符串是什么

一文带你了解Python中的字符串是什么

在《 详解Python拼接字符串的七种方式 》这篇文章里,我提到过,字符串是程序员离不开的事情。后来,我看到了一个英文版本的说法: There are few guarantees in...

Python Requests安装与简单运用

requests是python的一个HTTP客户端库,跟urllib,urllib2类似,那为什么要用requests而不用urllib2呢?官方文档中是这样说明的: python的标...

详解Python3 对象组合zip()和回退方式*zip

详解Python3 对象组合zip()和回退方式*zip

zip即将多个可迭代对象组合为一个可迭代的对象,每次组合时都取出对应顺序的对象元素组合为元组,直到最少的对象中元素全部被组合,剩余的其他对象中未被组合的元素将被舍弃。 keys =...

python 使用sys.stdin和fileinput读入标准输入的方法

1、使用sys.stdin 读取标准输入 [root@c6-ansible-20 script]# cat demo02.py #! /usr/bin/env python fro...

Python多线程编程简单介绍

创建线程 格式如下 复制代码 代码如下: threading.Thread(group=None, target=None, name=None, args=(), kwargs={})...