Tensorflowの基本の使い方②
前回はTensorflowを使うにあたっての考え方をを記事にしました。
今回はGithubに上がっていたソースコードをベースに、具体的にDeep Learningの学習、推論をするときの基本的な記述方法について書いていきたいと思います。
①ネットワークの記述
まずは使用するDeep Learningのネットワークを構築します。
今回はMNIST(手書き文字認識)のサンプルで使われるネットワークを題材にします。
Githubのコードを見ると関数の中で記述しているケースが多いようです。
def create_model(input):
#畳み込み+max pooling
x_1 = tf.reshape(input, [-1, 28, 28, 1])
k_0 = tf.Variable(tf.truncated_normal([4, 4, 1, 10], mean=0.0, stddev=0.1))
x_2 = tf.nn.conv2d(x_1, k_0, strides=[1, 3, 3, 1], padding='VALID')
x_3 = tf.nn.relu(x_2)
x_4 = tf.nn.max_pool(x_3, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')
#全結合
x_5 = tf.reshape(x_4, [-1, 160])
w_1 = tf.Variable(tf.zeros([160, 40]))
b_1 = tf.Variable([0.1] * 40)
x_6 = tf.matmul(x_5, w_1) + b_1
x_7 = tf.nn.relu(x_6)
w_2 = tf.Variable(tf.zeros([40, 10]))
b_2 = tf.Variable([0.1] * 10)
x_8 = tf.matmul(x_7, w_2) + b_2
#ソフトマックスで確率表現
output = tf.nn.softmax(x_8)
return output
②学習と学習結果の保存
①で作成した関数を使ってグラフを構築します。
加えて、学習時にはlossとoptimizerを設定し、sess_run()で学習を行います。
学習終了後は重みをckpt形式で保存します。
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
#グラフの構築
g = tf.Graph()
with g.as_default():
x = tf.placeholder(tf.float32, name='x')
output = model_create(x)
labels = tf.placeholder(tf.float32, name='labels')
#損失関数:交差エントロピー, 最適化:Adam
loss = -tf.reduce_sum(labels * tf.log(output))
optimizer = tf.train.AdamOptimizer().minimize(loss)
#重み保存用
saver = tf.train.Saver()
# 学習の実行
with tf.Session(graph=g) as sess:
sess.run(tf.global_variables_initializer())
for i in range(NUM_TRAIN):
batch = mnist.train.next_batch(BATCH_SIZE)
inout = {x: batch[0], labels: batch[1]}
_, loss_value = sess.run((train, loss), feed_dict=inout))
if i % OUTPUT_BY == 0:
print(loss_value)
#学習結果の保存
saver.save(sess, "model.ckpt")
➂学習結果の読み出しと推論
②で学習した結果を読み出して推論を実行します。
学習後、引き続き推論を行う場合には結果の読み出しは不要です。
学習済みモデルを読み出す前にはモデルを構築しておくことに注意が必要。
#グラフの構築
g = tf.Graph()
with g.as_default():
x = tf.placeholder(tf.float32, name='x')
output = model_create(x)
#重み読み出し用
saver = tf.train.Saver()
# 推論の実行
with tf.Session(graph=g) as sess:
#学習結果の読み出し
saver.restore(sess, "model.ckpt")
result = sess.run(output, feed_dict={x:test_img})
print(result)
ちなみに、ckpt保存時に生成される.metaファイルにはモデルの構造が含まれており、tf.train.import_meta_graph()で.metaファイルをロードすれば、事前にモデルを構築しておく必要がないようです。
参考サイト:TensorFlow で学習したモデルのグラフを `tf.train.import_meta_graph` でロードする - Qiita
まとめ
Tensorflowを使ったDeep Learningの学習、推論のコードの書き方をまとめました。
基本となる型のようなものができましたので、コードを自作する際にはこれをベースに修正を加えていけばいろいろ応用が利くと思います。
参考サイト
これらのサイトを参考にさせていただきました。ありがとうございました!
https://qiita.com/41semicolon/items/d159662385a6d72ee195
https://tensorflow.classcat.com/2018/04/20/tensorflow-programmers-guide-low-level-intro/
https://www.atmarkit.co.jp/ait/articles/1804/24/news134.html
https://qiita.com/cfiken/items/bcdd7eb945c5c3b2bb5f