ハードウェア技術者のスキルアップ日誌

某家電メーカーの技術者がスキルアップのために勉強したことを記録するブログです

【Keras】Triplet lossを使って学習する

以前、顔認識を行うAIモデルである、Facenetを動かしたときにTriplet lossについて少し触れました。

masaeng.hatenablog.com

Facenetは顔の類似度を特徴ベクトルの距離で表すことで、大量の顔画像を使って学習することなく、数枚の顔画像だけで顔照合ができるというものでした。これを可能にしているのがTriplet lossという損失関数です。

顔だけではなく、物体の照合もTriplet lossを使ったら同様にできるのではないかと思い、Triplet lossを使った学習のやり方について調べてみました。試行錯誤した中で、ハマってしまったところもあるので備忘録として残しておきます。

 

Tensorflow Addonを使う

Githubの実装やQiitaの記事などいろいろ見てみたのですが、一番簡単に使えそうだったのがTensorflowで用意されている関数です。

TensorFlow Addons Losses: TripletSemiHardLoss

tfa.losses.TripletSemiHardLoss()

 これをloss関数として渡せばOKです。

 

上記URLのサイトにはデータセットにmnistを使った学習例があります。これを元にモデルをmobilenetに変更し、データセットにcifar100を使って学習をしてみます。

最初に試したコード
import io
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

train_dataset, test_dataset = tfds.load(name="cifar100", split=['train', 'test'], as_supervised=True)

# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(32)
test_dataset = test_dataset.batch(32)

# model create
inputs = tf.keras.Input(shape=(None, None, 3))
x = tf.keras.layers.Lambda(lambda img: tf.image.resize(img, (224, 224)))(inputs)
x = tf.keras.layers.Lambda(tf.keras.applications.mobilenet.preprocess_input)(x)

model = tf.keras.applications.MobileNet(input_shape=(224,224,3), alpha=1.0,
                                         depth_multiplier=1, dropout=1e-3,
                                         include_top=False, weights=None,
                                         input_tensor=x, pooling="avg")

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001), 
loss=tfa.losses.TripletSemiHardLoss(distance_metric="L2", margin=0.1))
# Train the network
history = model.fit(train_dataset, epochs=10)

tensorflowのサイトにある事例はmodelの最後にL2正規化の処理があります。今回作成したmobilenetはL2正規化を含んでいないため、TripletSemiHardLossの引数でL2正規化を指定しています。また、Triplet lossのマージンを持たせたいときは引数marginに値を渡します。

このコードを実行すると、学習の途中で以下のようにlossがnanになってしまいます。

Epoch 1/10
  2/Unknown - 1s 302ms/step - loss: 1.0000
  876/Unknown - 258s 295ms/step - loss: nan

 

原因が分からず、苦戦したのですが、Stack Overflowに原因と解決策が書かれていました。
Nan loss in keras with triplet loss - Stack Overflow

本来triplet lossを使う際には、Anchor, Positive(Anchorと同じクラス), Negative(Anchorと違うクラス)となるデータを3つペアで入力させる必要があります。Tensorflowの関数を使うとこれを意識する必要がなく、train_dataset からうまくAnchor, Positive, Negativeに振り分けて入力してくれるようです。ですが、データセットのクラス数に対してバッチサイズが小さいと、ミニバッチの中にPositiveが存在しないケースが発生することが在ります。この時にtriplet lossがnanになるようです。

改善版

cifar100のクラス数が100なので、バッチ数をクラス数より多い128にします。

# Build your input pipelines
train_dataset = train_dataset.shuffle(1024).batch(128) #32->128
test_dataset = test_dataset.batch(
128) #32->128

これにより、うまく学習を行うことができました。

Epoch 1/10
 391/391 [==============================] - 273s 699ms/step - loss: 0.1384
Epoch 2/10
391/391 [==============================] - 274s 702ms/step - loss: 0.0848
Epoch 3/10
391/391 [==============================] - 274s 702ms/step - loss: 0.0908
Epoch 4/10
391/391 [==============================] - 274s 701ms/step - loss: 0.0545
Epoch 5/10
391/391 [==============================] - 274s 700ms/step - loss: 0.0378
Epoch 6/10
391/391 [==============================] - 272s 696ms/step - loss: 0.0264
Epoch 7/10
391/391 [==============================] - 274s 700ms/step - loss: 0.0249
Epoch 8/10
391/391 [==============================] - 274s 702ms/step - loss: 0.0281
Epoch 9/10
391/391 [==============================] - 275s 703ms/step - loss: 0.0449
Epoch 10/10
391/391 [==============================] - 273s 699ms/step - loss: 0.0548 

 

まとめ

Triplet lossを使った学習方法についてまとめてきました。高位APIを使うと簡単に実装ができる一方で、問題が起きたときに何が起きているかわかりにくいというデメリットがあります。できるだけ内部処理も理解して、低位APIでも書けるようになっていることが重要かなと思います。

また、今回の学習結果を使って、距離による分類がうまくいくかも機会があれば試してみたいと思います。