ハードウェア技術者のスキルアップ日誌

某家電メーカーの技術者がスキルアップのために勉強したことを記録するブログです

Tensorflow - ckptからpbへの変換方法

PC上で学習したモデルを動かす際にはckptファイルで問題ないですが、iOS, Android上や組み込み機器上で動かしたい場合、pb(protocol buffer)形式のファイルが必要となります。ckptからpbへの変換方法を調べたのでメモしておきたいと思います。

ckpt, pbとは?

変換方法に行く前にckpt, pbのおさらいです。

Tensorflowで重み、ネットワーク構造を保存するデータのファイルが.ckptです。
check pointの略(?)です。

ckptファイルは3種類あります。
  ckpt.meta : モデルの構造を記述 重み情報はない
  ckpt.data : 実際の重みが入ったバイナリ
       ckpt.data-00000-of-00001のようなファイル名となる
  ckpt.index : どのファイルがどのstepのものかを一意に特定するためのバイナリ

 

一方、protocol bufferは元はGoogleが開発したシリアライズフォーマットとのことで、Tensorflowのグラフ(モデル構造と重み)をこのprotocol buffer形式で記述できます。

Protocol Buffers - Wikipedia

 

変換方法

以下の参考サイトでいろいろなやり方が紹介されていますが、私の環境でうまくいったサンプルコードです。

import tensorflow as tf
import models

graph = tf.get_default_graph()

sess = tf.Session()

saver = tf.train.import_meta_graph('test.ckpt.meta')
saver.restore(sess, 'test.ckpt')

tf.train.write_graph(sess.graph_def, '.', 'graph.pb', as_text=False)

まず前提として、学習は完了しており、そのネットワークや重みデータはckpt形式で保存されています。(test.ckpt) このckptファイルを用いてグラフを復元し、pb形式で保存しなおします。

tf.train.import_meta_graphを使って.metaからモデルをロードします。これを使えば新たにモデルのインスタンスを作る必要がありません。
そして、saver.restoreでckptを読み込み、学習済みモデルを復元します。
pb形式への変換にはtf.train.write_graphを使用します。 

こちらのサイトには、通常のモデルでは学習済みのWeightやBiasを保持するためのtf.Variableの変数を持つが、これをpbファイルに保存できないため、 
graph_util.convert_variables_to_constants() を使ってConstに変換する必要があると記載されています。
しかし、私の環境ではこれをやるとうまくいかなかったです。
もし、上記のやり方でNGであればこちらをお試しください。

 

参考サイト

https://codeday.me/jp/qa/20190407/569692.html

https://tyfkda.github.io/blog/2016/09/14/tensorflow-protobuf.html

http://workpiles.com/2016/07/tensorflow-protobuf-dump/

https://gist.github.com/funwarioisii/68ed46d8ccfcbc31a456b7c4166b8d0e