ハードウェア技術者のスキルアップ日誌

某家電メーカーの技術者がスキルアップのために勉強したことを記録するブログです

Tensorflow ー 複数の学習済みモデルを同時に実行する

GITHUBで公開されているディープラーニングの学習済みモデルを流用し、
USBカメラ映像を入力して推論するというプログラムをいくつか作ってきました。
今まで動かしたプログラムは全てモデル1つだけを動かしていましたが、
複数の学習済みモデルを組み合わせて機能を作れないか試してみました。
動かし方を調べたのでまとめておきます。

 

モデルが一つの場合

以下のような構成でプログラムを書いていました。
USBカメラの映像をWhileループ内で取得するため、ループの外でセッションを作成し、学習済みモデルの復元までを行います。ループ内ではsess.run()のみを行います。

import tensorflow as tf
import cv2

video_capture = cv2.VideoCapture(0)

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, model_path)

    while True:
        ret, frame = video_capture.read()
        sess.run()

        if cv2.waitKey(10) & 0xFF == ord('q'):
            break
video_capture.release() cv2.destroyAllwindows()

 

 

ネットワークが二つの場合

2つの異なるグラフを作成し、それぞれのグラフのセッションを構築してモデルを復元します。そして、whileループ内でそれぞれのセッションを実行しています。

従来の処理ではwith tf.Session() as sess:の中にwhileループを置いていましたが、
withを使わずにSessionを作成するように少し変更しました。

import tensorflow as tf
import cv2

video_capture = cv2.VideoCapture(0)

with tf.Graph().as_default() as graph_1:
     saver1 = tf.train.Saver()
sess1 = tf.Session(graph = graph_1)
saver1.restore(sess1, model_path1)

with tf.Graph().as_default() as graph_2:
     saver2 = tf.train.Saver()
sess2 = tf.Session(graph = graph_2)
saver2.restore(sess2, model_path2)

     while True:
            ret, frame = video_capture.read()
            sess1.run()
            sess2.run()

            if cv2.waitKey(10) & 0xFF == ord('q'):
                 break

    video_capture.release()
cv2.destroyAllwindows()

 

参考サイト

事前に訓練された複数のTensorflowネットを同時に実行する - コードログ