可视化工具Graphviz
一.安装
Graphviz 官网：https://graphviz.org/
Mac 用户建议直接用 Homebrew 安装，官网上提供的版本比较旧。
1.安装homebrew
打开终端复制、粘贴以下命令:
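/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
（早期教程中的 ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 安装方式已被 Homebrew 官方废弃，不再可用。）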
2.安装Graphviz
Homebrew 安装完毕后，运行 brew install graphviz 即可。
安装完成后的示例代码：
import torchvision.models as models
import torch
from torchsummary import summary
from torch.autograd import Variable
import torch
from torch.autograd import Variable
from graphviz import Digraph
import os
def make_dot(var, params=None):
    """Build a Graphviz ``Digraph`` of the autograd graph behind ``var``.

    Starting from ``var.grad_fn`` (or from the ``grad_fn`` of every element
    when ``var`` is a tuple), walk backwards through each node's
    ``next_functions`` and, where the attribute exists, ``saved_tensors``.
    Every visited object becomes one graph node; every child -> parent link
    becomes one edge.

    Node styling:
      * orange          - tensors (reached through ``saved_tensors``)
      * lightblue       - objects exposing ``.variable`` (leaf accumulators),
                          labelled with their name from ``params`` and shape
      * darkolivegreen1 - the ``grad_fn`` of the output(s)
      * default         - any other backward function, labelled by type name

    Args:
        var: an output tensor, or a tuple of output tensors, whose history
            is drawn.
        params: optional mapping ``name -> tensor`` (for example
            ``dict(model.named_parameters())``) used to label leaf nodes.
            Leaves that do not appear in ``params`` get an empty name
            instead of raising ``KeyError``.

    Returns:
        The ``graphviz.Digraph``; call ``.view()`` or ``.render()`` on it.
    """
    param_map = {}
    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled', shape='box', align='left',
                     fontsize='12', ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
    # Sentinel marking an exhausted child iterator during the traversal.
    _done = object()

    def size_to_str(size):
        return '(' + ', '.join(['%d' % v for v in size]) + ')'

    output_nodes = ((var.grad_fn,) if not isinstance(var, tuple)
                    else tuple(v.grad_fn for v in var))

    def add_node(fn):
        """Emit the styled graph node for ``fn`` and mark it as seen."""
        if torch.is_tensor(fn):
            # note: this used to show .saved_tensors in pytorch0.2, but stopped
            # working as it was moved to ATen and Variable-Tensor merged
            dot.node(str(id(fn)), size_to_str(fn.size()), fillcolor='orange')
        elif hasattr(fn, 'variable'):
            u = fn.variable
            # .get(): a leaf that is not a named parameter (e.g. an input
            # with requires_grad=True) used to raise KeyError here.
            name = param_map.get(id(u), '')
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(fn)), node_name, fillcolor='lightblue')
        elif fn in output_nodes:
            dot.node(str(id(fn)), str(type(fn).__name__),
                     fillcolor='darkolivegreen1')
        else:
            dot.node(str(id(fn)), str(type(fn).__name__))
        seen.add(fn)

    def children(fn):
        """Yield the inputs of ``fn`` in the order the original code visited them."""
        if hasattr(fn, 'next_functions'):
            for u in fn.next_functions:
                if u[0] is not None:
                    yield u[0]
        if hasattr(fn, 'saved_tensors'):
            for t in fn.saved_tensors:
                yield t

    def add_nodes(root):
        """Depth-first walk from ``root`` using an explicit stack.

        This reproduces the emission order of the former recursive version
        (node, then for each child: edge, then that child's whole subtree).
        It avoids RecursionError on deep graphs, where the depth of the
        autograd chain would otherwise exceed the interpreter's recursion
        limit.
        """
        if root in seen:
            return
        add_node(root)
        stack = [(root, children(root))]
        while stack:
            parent, pending = stack[-1]
            child = next(pending, _done)
            if child is _done:
                stack.pop()
                continue
            dot.edge(str(id(child)), str(id(parent)))
            if child not in seen:
                add_node(child)
                stack.append((child, children(child)))

    if isinstance(var, tuple):
        for v in var:
            add_nodes(v.grad_fn)
    else:
        add_nodes(var.grad_fn)
    return dot
if __name__ == "__main__":
    # The pip ``graphviz`` package is only a Python wrapper; rendering needs
    # the Graphviz executables (``dot``). Homebrew installs them into
    # /usr/local/bin (Intel Macs) or /opt/homebrew/bin (Apple Silicon).
    # The former entry, /Library/Python/2.7/site-packages, holds Python
    # packages and never contains the ``dot`` binary.
    os.environ["PATH"] += os.pathsep + os.pathsep.join(
        ['/usr/local/bin', '/opt/homebrew/bin'])

    ## visual model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # ``models.vgg`` is the torchvision.models.vgg sub-module, not a
    # constructor, so calling it raises TypeError; build a concrete variant.
    model = models.vgg16().to(device)
    # The input must live on the same device as the model; creating it on
    # the CPU made the forward pass fail whenever CUDA was available.
    # (Variable is no longer needed: tensors carry autograd history directly.)
    x = torch.randn(1, 3, 224, 224, device=device)
    vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
    vis_graph.view()
3.安装 Python 对应的包：sudo pip install graphviz（上面的示例代码通过 from graphviz import Digraph 使用这个包）
注意：用 pip 安装的 graphviz 只是 Python 接口，Graphviz 本身并不是一个 Python 工具，因此仍然需要安装 Graphviz 的可执行程序（如 dot，即第 2 步用 brew 安装的内容）。我最初运行报错，查阅资料后才发现原来是没有安装 Graphviz 的可执行程序。
另一个显示网络结构的工具：torchsummary
使用 sudo pip install torchsummary 进行安装。