【PyTorch】PyTorchVizでニューラルネットワークを可視化する

どーも、ぐるたか@guru_takaです。

PyTorchで生成したニューラルネットワークの構造や処理の流れがブラックボックス化しやすいので、可視化したいと思いました。

ググってみると、色んな方法があるとわかったのですが、ライブラリ『torchviz』が1番とっつきやすかったです!

ここでは、PyTorchVizでPyTorchで生成したニューラルネットワークをビジュアライズする方法を紹介します。

参考 szagoruyko/pytorchviz: A small package to create visualizations of PyTorch execution graphs

PyTorchVizのインストール

コマンドライン
$ brew install graphviz
$ pip3 install torchviz

PyTorchVizでビジュアライズ化

こちらの記事で紹介されいてるサンプルコードを引用します。

参考 PyTorchのネットワーク構造を可視化できるものを探してみたmsdd’s diary
python
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorchViz
from torchviz import make_dot

# NN構築用のクラス
class NnModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10,5)
        self.fc2 = nn.Linear(5,2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x))
        return x

# モデル定義
model = NnModel()

# 擬似乱数のシード固定
torch.manual_seed(0)
datas = torch.randn(1,10)

# 予測値
y = model(datas)

# ビジュアライズ化
# 予測値が引数になる!
make_dot(y, params=dict(model.named_parameters()))

すると、こんな図が出てきます!

最後に

tensorboardと比べると、PyTorch内の処理(AddmmBackwardなど)が書かれていて、ビックリしました。処理の流れも追えますね!

ただ、AddmmBackwardやTbackwrdなどは、y=ax+bをしているんだろう…といった直感的な理解に留めるくらいでも良さそうです。今後、色々と出てきそうなので笑

また何か気づいたことあれば、追記していきます!

コメントを残す