pytorch 模型可视化的例子

yipeiwu_com6年前Python基础

如下所示:

一. visualize.py

from graphviz import Digraph
import torch
from torch.autograd import Variable
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

二. 使用步骤

import torch
from torch.autograd import Variable
from models import *
from visualize import make_dot
x = Variable(torch.rand(1, 3, 256, 256))
model = GeneratorUNet()
y = model(x)
g = make_dot(y)
g.view()

三. 效果展示

以上这篇pytorch 模型可视化的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python numpy函数中的linspace创建等差数列详解

python numpy函数中的linspace创建等差数列详解

前言 本文主要给大家介绍的是关于linspace创建等差数列的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧。 numpy.linspace 是用于创建一个由等...

python使用itchat库实现微信机器人(好友聊天、群聊天)

itchat是一个开源的微信个人号接口,可以使用该库进行微信网页版中的所有操作,比如:所有好友、添加好友、拉好友群聊、微信机器人等等。详细用户请看文档介绍,在这里。 本文主要使用该库完成...

python使用pandas实现数据分割实例代码

本文研究的主要是Python编程通过pandas将数据分割成时间跨度相等的数据块的相关内容,具体如下。 先上数据,有如下dataframe格式的数据,列名分别为date、ip,我需要统计...

浅析python 中大括号中括号小括号的区分

python语言最常见的括号有三种,分别是:小括号( )、中括号[ ]和大括号也叫做花括号{ }。其作用也各不相同,分别用来代表不同的python基本内置数据类型。 1.python中的...

Python多进程写入同一文件的方法

Python多进程写入同一文件的方法

最近用python的正则表达式处理了一些文本数据,需要把结果写到文件里面,但是由于文件比较大,所以运行起来花费的时间很长。但是打开任务管理器发现CPU只占用了25%,上网找了一下原因发现...