DETR(End-to-End Object Detection with Transformers)を動かしてみた
先日、AIに関するニュース記事を見ていて、物体検出の新しい手法が発表されたと、大々的に報じられていたものがあったので、非常に気になっていました。
DETR(End-to-End Object Detection with Transformers)というタイトルで、NMSなどの人手による設計をなくし、End-to-Endで物体検出を実現できる手法とのことです。
今回は久しぶりにこちらのモデルを動かしてみたいと思います。
論文 : https://arxiv.org/pdf/2005.12872.pdf
GitHub : https://github.com/facebookresearch/detr
環境
モデルを動かした環境は以下の通りです。
・OS Ubuntu16.04
・GPU GeForce 1080Ti
・CUDA 10.2
・CuDNN 7.6.5
・Pytorch 1.6.0
・torchvision 0.7.0
※AnaconaでPython 3.6の仮想環境を構築
実装
こちらにColab上で動作するサンプルコードが公開されています。
https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb#scrollTo=_GQzINI-FBWp
これだけで動作させることは出来ますが、カメラの映像や動画に対して推論を行うようにコードを修正していきます。
①ライブラリのインストール
必要なライブラリをインストールします。バウンディングボックス描画とフレームレート計算用にcv2とtimeライブラリを作者のコードから追加でインストールします。
今回使用する学習済みモデルはCOCOデータセットで学習したものなので、検出結果を表記する際にクラスを表示できるようにCOCOのクラス名がリストで定義されています。
import math from PIL import Image import requests import matplotlib.pyplot as plt #%config InlineBackend.figure_format = 'retina' import ipywidgets as widgets from IPython.display import display, clear_output import torch from torch import nn from torchvision.models import resnet50 import torchvision.transforms as T torch.set_grad_enabled(False) import cv2 #追加 import time #追加
# COCO classes CLASSES = [ 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' ]
②使用する関数の定義
サンプルコードでも使用されている関数をそのまま流用していますが、バウンディングボックスを画像に描画する関数(put_rect)のみ新たに作成しています。
# standard PyTorch mean-std input image normalization
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
#以下関数を修正
def put_rect(cv2_img, prob, boxes):
colors = COLORS * 100
output_image = cv2_img
for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
xmin = (int)(xmin)
ymin = (int)(ymin)
xmax = (int)(xmax)
ymax = (int)(ymax)
c[0],c[2]=c[2],c[0]
c = tuple([(int)(n*255) for n in c])
output_image = cv2.rectangle(output_image,(xmin,ymin),(xmax,ymax),(0,0,255), 4)
cl = p.argmax()
text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
output_image = cv2.rectangle(output_image,(xmin,ymin-20),(xmin+len(text)*10,ymin),(0,255,255),-1)
output_image = cv2.putText(output_image,text,(xmin,ymin-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,0),2)
return output_image
➂モデルのロードと動画の読み込み
学習済みモデルはpytorchのライブラリ経由でダウンロードできます。今回はresnet50ベースのものを使いますが、他にもいくつか学習済みモデルが公開されているので、もし興味があればGitHubの方もご参照ください。
https://github.com/facebookresearch/detr
動画を読み込むのにOpenCVの関数cv2.VideoCapture()を使用します。カメラ入力か、動画入力かで引数を変えます。
#model load model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True) model.eval() #model = model.cuda() #GPUを使用する場合はこちら video_capture = cv2.VideoCapture(0) #USBカメラ入力 #video_capture = cv2.VideoCapture("video_path") #動画読み込み
# 幅と高さを取得 width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) size = (width, height) #総フレーム数とフレームレートを取得 frame_count = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) #動画読み込みの場合 frame_rate = int(video_capture.get(cv2.CAP_PROP_FPS)) fmt = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') out = cv2.VideoWriter('./result.mp4', fmt, frame_rate, size)
seconds = 0.0
fps = 0.0
➃メインのwhileループ/forループ
こちらのwhile/forループで動画のフレームを取得し、推論を実行。その結果を入力画像にバウンディングボックスとして貼り付け、表示するということを繰り返します。合わせて、処理速度を計算し、printで表示します。
while True: # USBカメラ入力の場合
#for i in range(frame_count): # 動画読み込みの場合
start = time.time() _, frame = video_capture.read() frame_cvt = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) im =Image.fromarray(frame_cvt) # mean-std normalize the input image (batch-size: 1) #img = transform(im).unsqueeze(0).cuda() #GPUを使用する場合 img = transform(im).unsqueeze(0) #CPUしかなければこちら # propagate through the model with torch.no_grad(): outputs = model(img) # keep only predictions with 0.7+ confidence probas = outputs['pred_logits'].softmax(-1)[0, :, :-1] keep = probas.max(-1).values > 0.9 # convert boxes from [0; 1] to image scales #bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep].cpu(), im.size) #GPU bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size) #CPU #display output_image = put_rect(frame, probas[keep], bboxes_scaled) # End time end = time.time() # Time elapsed seconds = (end - start) print("time:{:.3f} msec".format(seconds*1000) ) # Calculate frames per second fps = ( fps + (1/seconds) ) / 2 cv2.putText(output_image,'{:.2f}'.format(fps)+' fps',(10,50),cv2.FONT_HERSHEY_SIMPLEX,0.8,(255,0,0),3) cv2.imshow('Video', output_image)
out.write(output_image) # Press Q to stop! if cv2.waitKey(1) & 0xFF == ord('q'): break
video_capture.release()
out.release()
cv2.destroyAllWindows()
実行結果
サンプル動画に対して推論を行った結果がこちらです。
DETR(End-to-End Object Detection with Transformers)を動かしてみた
人物、車を正確に検知できているようです。処理時間はVGA解像度の入力で1フレーム70-80ms程度でした。
動画はこちらのフリー素材を使用させていただきました。
https://pixabay.com/ja/videos/%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A8%E3%83%BC%E3%82%AF%E5%B8%82-%E3%83%9E%E3%83%B3%E3%83%8F%E3%83%83%E3%82%BF%E3%83%B3-%E4%BA%BA-1044/
まとめ
革新的な物体検知の手法とされているDETR(End-to-End Object Detection with Transformers)を実際に動かしてみました。今回は試してみただけで、SSDなど従来手法との比較ができていないため、もう少し従来手法も勉強して性能面でどのような変化があるのかまで踏み込むことができたらと思っています。