Tensorflowで作成したグラフのノード名、型を表示する方法
Tensorflowの学習済みモデル、重みデータのcheckpointファイルをProtocolbuf形式(pbファイル)に変換する際にグラフのノード名を指定する必要があったのですが、ノード名を知る方法がわからなかったので調べてみました。
結論から記載すると、以下の記述ですべてのネットワーク構成を可視化できます。
graph = tf.Graph() with graph.as_default(): for op in graph.get_operations(): print(op.type) # type print(op.name) # name print(op.op_def) # protocol buf
関数graph.get_operations()でグラフのOperations一覧を取得します。
各ノードごとに型と名前とprotocol bufを表示していきます。
ネットワークの層数が多いと見ていくのが大変なので、typeがReLUのみのノード名を表示するなど、if文で条件を絞ると確認がしやすくなります。
参考サイト
http://docs.fabo.io/tensorflow/building_graph/tensorflow_graph_part1.html