pytorch打印网络结构的实例
作者:每天都要深度学习 发布时间:2023-11-04 15:15:51
标签:pytorch,打印,网络结构
最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。
(1)安装环境:graphviz
conda install -n pytorch python-graphviz
或:
sudo apt-get install graphviz
或者从官网下载,按此教程。
(2)生成网络结构的代码:
def make_dot(var, params=None):
""" Produces Graphviz representation of PyTorch autograd graph
Blue nodes are the Variables that require grad, orange are Tensors
saved for backward in torch.autograd.Function
Args:
var: output Variable
params: dict of (name, Variable) to add names to node that
require grad (TODO: make optional)
"""
if params is not None:
assert isinstance(params.values()[0], Variable)
param_map = {id(v): k for k, v in params.items()}
node_attr = dict(style='filled',
shape='box',
align='left',
fontsize='12',
ranksep='0.1',
height='0.2')
dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
seen = set()
def size_to_str(size):
return '('+(', ').join(['%d' % v for v in size])+')'
def add_nodes(var):
if var not in seen:
if torch.is_tensor(var):
dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
elif hasattr(var, 'variable'):
u = var.variable
name = param_map[id(u)] if params is not None else ''
node_name = '%s\n %s' % (name, size_to_str(u.size()))
dot.node(str(id(var)), node_name, fillcolor='lightblue')
else:
dot.node(str(id(var)), str(type(var).__name__))
seen.add(var)
if hasattr(var, 'next_functions'):
for u in var.next_functions:
if u[0] is not None:
dot.edge(str(id(u[0])), str(id(var)))
add_nodes(u[0])
if hasattr(var, 'saved_tensors'):
for t in var.saved_tensors:
dot.edge(str(id(t)), str(id(var)))
add_nodes(t)
add_nodes(var.grad_fn)
return dot
(3)打印网络结构:
import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
class CNN(nn.module):
def __init__(self):
******
def forward(self,x):
******
return out
*****************************
def make_dot(): #复制上面的代码
*****************************
if __name__ == '__main__':
net = CNN()
x = Variable(torch.randn(1, 1, 1024,1024))
y = net(x)
g = make_dot(y)
g.view()
params = list(net.parameters())
k = 0
for i in params:
l = 1
print("该层的结构:" + str(list(i.size())))
for j in i.size():
l *= j
print("该层参数和:" + str(l))
k = k + l
print("总参数数量和:" + str(k))
(4)结果展示(例如这是一个resnet block类型的网络):
来源:https://blog.csdn.net/Lucifer_zzq/article/details/80657513


猜你喜欢
- 前言Python中使用SSH需要用到OpenSSH,而OpenSSH依赖于paramiko模块,而paramiko模块又依赖于pycrypt
- 1.本人第一次学python做出来的,当时满满的成就感,当作纪念!!!!!非常简单,复制即可使用代码块import json#把字符串类型的
- 你知道(X)HTML中最多余的标签中是什么吗?在我看来就是这个<a>标签,不错,就是每个网站使用最多的超级链接标签<a&g
- 注:Unicode相关知识的详细介绍请参考UTF-8, UTF-16, UTF-32 & BOM。 对于UTF-8/16/32而言,
- 首先安装解析的第三方包:go get gopkg.in/yaml.v2示例:package main import ( "os&q
- 在 Go 中,数组和切片的功能其实是类似的,都是用来存储一种类型元素的集合。数组是固定长度的,而切片的长度是可以调整的数组(array)我们
- 今天在日常维护一个网站时,发现该网站的留言程序没有经过严格的验证过滤,导致了将近十万条垃圾数据。而其中又不乏重要信息,需要清理数据,以及增加
- Mac下mysql安装配置方法图文教程记录如下使用安装包安装mysql双击pkg文件安装一路向下,记得保存最后弹出框中的密码(它是你mysq
- 概要在自然语言处理(NLP)领域,情感分析及分类是一项十分热门的任务。它的目标是从文本中提取出情感信息和意义,通常分为两类:正向情感和负向情
- 其实网上已经有很多ASP生成htm的文章了,有一种方法是ASP+XML的生成方法,虽然有一种好处就是不用程序写模版就可以直接引用原来的要生成
- 举例为大家介绍如何运用命令行实现MySQL导出导入数据库一、命令行导出数据库1.进入MySQL目录下的bin文件夹:cd MySQL中到bi
- 前言如果你在寻找python工作,那你的面试可能会涉及Python相关的问题。通过对网络资料的收集整理,本文列出了100道python的面试
- 前言在使用PC时与PC交互的主要途径是看屏幕显示、听声音,点击鼠标和敲键盘等等。在自动化办公的趋势下,繁琐的工作可以让程序自动完成。比如自动
- 在整个安装的过程中也遇到了很多的坑,故此做个记录,争取下次不再犯!我的整个基本配置如下:电脑环境如下:win10(64位)+CPU:E5-2
- 前言众所周知在java或php等很多面向对象的语言中, 异常处理是依靠throw、catch来进行的。在go语言中,panic和recove
- 1、replicate_do_db 和 replicate_ignore_db 不要同时出现。容易出现混淆。也是毫无意义的。 Replica
- 问题产生:pycharm→settings→Project interpreter→下载matplotlib包运行代码,出现以下提示:找不到
- 很多DBA目前还停留在Oracle 9i或者10g,究其原因有可能是Oracle 11g的价格问题。本文将为大家讲解在Windows 7下安
- 技巧之一:提高使用Request集合的效率 访问一个ASP集合来提取一个值是费时的、占用计算资源的过程。因为这个操作包含了一系列对相关集合的
- 引言日常开发中,我们经常会使用到group by。亲爱的小伙伴,你是否知道group by的工作原理呢?group by和having有什么