pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法

yipeiwu_com6年前Python基础

如下所示:

#获取模型权重
for k, v in model_2.state_dict().iteritems():
 print("Layer {}".format(k))
 print(v)

#获取模型权重
for layer in model_2.modules():
 if isinstance(layer, nn.Linear):
  print(layer.weight)
#将一个模型权重载入另一个模型
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
 load = torch.load('/home/huangqk/.torch/models/vgg19-dcbb9e9d.pth')
 load_state = {k: v for k, v in load.items() if k not in ['classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']}
 model_state = model.state_dict()
 model_state.update(load_state)
 model.load_state_dict(model_state)
return model
# 对特定层注入hook
def hook_layers(model):
 def hook_function(module, inputs, outputs):
  recreate_image(inputs[0])

 print(model.features._modules)
 first_layer = list(model.features._modules.items())[0][1]
 first_layer.register_forward_hook(hook_function) 
#获取层
x = someinput
for l in vgg.features.modules():
 x = l(x)
modulelist = list(vgg.features.modules())
for l in modulelist[:5]:
 x = l(x)
keep = x
for l in modulelist[5:]:
 x = l(x)
# 提取vgg模型的中间层输出
# coding:utf8
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple


class Vgg16(torch.nn.Module):
 def __init__(self):
  super(Vgg16, self).__init__()
  features = list(vgg16(pretrained=True).features)[:23]
  # features的第3,8,15,22层分别是: relu1_2,relu2_2,relu3_3,relu4_3
  self.features = nn.ModuleList(features).eval()

 def forward(self, x):
  results = []
  for ii, model in enumerate(self.features):
   x = model(x)
   if ii in {3, 8, 15, 22}:
    results.append(x)

  vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])
  return vgg_outputs(*results)

以上这篇pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python3字符串学习教程

字符串类型是python里面最常见的类型,是不可变类型,支持单引号、双引号、三引号,三引号是一对连续的单引号或者双引号,允许一个字符串跨多行。 字符串连接:前面提到的+操作符可用于字符串...

解决python 读取 log日志的编码问题

解决python 读取 log日志的编码问题

1.我要读取log日志的”执行成功”的个数,log日志编码格式为GBK 2.显示报错,大致意思是说utf-8的代码不能解析log日志 3.后来想想把log日志用GBK编码读出来,写到...

Pandas中resample方法详解

Pandas中的resample,重新采样,是对原样本重新处理的一个方法,是一个对常规时间序列数据重新采样和频率转换的便捷的方法。 方法的格式是: DataFrame.resampl...

python实现PID算法及测试的例子

python实现PID算法及测试的例子

PID算法实现 import time class PID: def __init__(self, P=0.2, I=0.0, D=0.0): self.Kp = P...

解决django中ModelForm多表单组合的问题

解决django中ModelForm多表单组合的问题

django是python语言快速实现web服务的大杀器,其开发效率可以非常的高!但因为秉承了语言的灵活性,django框架又太灵活,以至于想实现任何功能都有种“条条大路通罗马”的感觉。...