【Keras】Triplet lossを使って学習する
以前、顔認識を行うAIモデルである、Facenetを動かしたときにTriplet lossについて少し触れました。
Facenetは顔の類似度を特徴ベクトルの距離で表すことで、大量の顔画像を使って学習することなく、数枚の顔画像だけで顔照合ができるというものでした。これを可能にしているのがTriplet lossという損失関数です。
顔だけではなく、物体の照合もTriplet lossを使ったら同様にできるのではないかと思い、Triplet lossを使った学習のやり方について調べてみました。試行錯誤した中で、ハマってしまったところもあるので備忘録として残しておきます。
Tensorflow Addonを使う
Githubの実装やQiitaの記事などいろいろ見てみたのですが、一番簡単に使えそうだったのがTensorflowで用意されている関数です。
TensorFlow Addons Losses: TripletSemiHardLoss
tfa.losses.TripletSemiHardLoss()
これをloss関数として渡せばOKです。
上記URLのサイトにはデータセットにmnistを使った学習例があります。これを元にモデルをmobilenetに変更し、データセットにcifar100を使って学習をしてみます。
最初に試したコード
import io import numpy as np import tensorflow as tf import tensorflow_addons as tfa import tensorflow_datasets as tfds train_dataset, test_dataset = tfds.load(name="cifar100", split=['train', 'test'], as_supervised=True) # Build your input pipelines train_dataset = train_dataset.shuffle(1024).batch(32) test_dataset = test_dataset.batch(32) # model create inputs = tf.keras.Input(shape=(None, None, 3)) x = tf.keras.layers.Lambda(lambda img: tf.image.resize(img, (224, 224)))(inputs) x = tf.keras.layers.Lambda(tf.keras.applications.mobilenet.preprocess_input)(x) model = tf.keras.applications.MobileNet(input_shape=(224,224,3), alpha=1.0, depth_multiplier=1, dropout=1e-3, include_top=False, weights=None, input_tensor=x, pooling="avg") # Compile the model model.compile( optimizer=tf.keras.optimizers.Adam(0.001),
loss=tfa.losses.TripletSemiHardLoss(distance_metric="L2", margin=0.1))
# Train the network
history = model.fit(train_dataset, epochs=10)
tensorflowのサイトにある事例はmodelの最後にL2正規化の処理があります。今回作成したmobilenetはL2正規化を含んでいないため、TripletSemiHardLossの引数でL2正規化を指定しています。また、Triplet lossのマージンを持たせたいときは引数marginに値を渡します。
このコードを実行すると、学習の途中で以下のようにlossがnanになってしまいます。
Epoch 1/10
2/Unknown - 1s 302ms/step - loss: 1.0000
876/Unknown - 258s 295ms/step - loss: nan
原因が分からず、苦戦したのですが、Stack Overflowに原因と解決策が書かれていました。
Nan loss in keras with triplet loss - Stack Overflow
本来triplet lossを使う際には、Anchor, Positive(Anchorと同じクラス), Negative(Anchorと違うクラス)となるデータを3つペアで入力させる必要があります。Tensorflowの関数を使うとこれを意識する必要がなく、train_dataset からうまくAnchor, Positive, Negativeに振り分けて入力してくれるようです。ですが、データセットのクラス数に対してバッチサイズが小さいと、ミニバッチの中にPositiveが存在しないケースが発生することが在ります。この時にtriplet lossがnanになるようです。
改善版
cifar100のクラス数が100なので、バッチ数をクラス数より多い128にします。
# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(128) #32->128
test_dataset = test_dataset.batch(128) #32->128
これにより、うまく学習を行うことができました。
まとめ
Triplet lossを使った学習方法についてまとめてきました。高位APIを使うと簡単に実装ができる一方で、問題が起きたときに何が起きているかわかりにくいというデメリットがあります。できるだけ内部処理も理解して、低位APIでも書けるようになっていることが重要かなと思います。
また、今回の学習結果を使って、距離による分類がうまくいくかも機会があれば試してみたいと思います。