どーも、ぐるたか@guru_takaです。
PyTorchで生成したニューラルネットワークの構造や処理の流れがブラックボックス化しやすいので、可視化したいと思いました。
ググってみると、色んな方法があるとわかったのですが、ライブラリ『torchviz』が1番とっつきやすかったです!
ここでは、PyTorchVizでPyTorchで生成したニューラルネットワークをビジュアライズする方法を紹介します。
参考 szagoruyko/pytorchviz: A small package to create visualizations of PyTorch execution graphsPyTorchVizのインストール
コマンドライン
$ brew install graphviz
$ pip3 install torchviz
PyTorchVizでビジュアライズ化
こちらの記事で紹介されいてるサンプルコードを引用します。
参考 PyTorchのネットワーク構造を可視化できるものを探してみたmsdd’s diary python
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyTorchViz
from torchviz import make_dot
# NN構築用のクラス
class NnModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10,5)
self.fc2 = nn.Linear(5,2)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.softmax(self.fc2(x))
return x
# モデル定義
model = NnModel()
# 擬似乱数のシード固定
torch.manual_seed(0)
datas = torch.randn(1,10)
# 予測値
y = model(datas)
# ビジュアライズ化
# 予測値が引数になる!
make_dot(y, params=dict(model.named_parameters()))
すると、こんな図が出てきます!
最後に
tensorboardと比べると、PyTorch内の処理(AddmmBackwardなど)が書かれていて、ビックリしました。処理の流れも追えますね!
ただ、AddmmBackwardやTbackwrdなどは、y=ax+b
をしているんだろう…といった直感的な理解に留めるくらいでも良さそうです。今後、色々と出てきそうなので笑
また何か気づいたことあれば、追記していきます!
コメントを残す