0.5から始める機械学習

Machine Learning, Deep Learning, Computer Vision に関する備忘録

【PyTorch】 DataLoaderのバッチ取り出し時の挙動について(2)

こんにちは。

先日に引き続き、DataLoaderクラスについてTipsを少し。

nodaki.hatenablog.com

今回はPyTorchのDataLoaderクラスを使用している時に少し引っかかったポイントがあったのでご紹介しようと思います。

結論から言うと、データのshapeがバッチ内で全て同一でないとエラーが起きてしまいます

その理由と対策について少し触れていきます。

エラー原因

DataLoaderを使用し、バッチ入力を取り出そうとした時以下のようなエラーが起きてしまいました。

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 1 and 2 in dimension 1 at /Users/soumith/minicondabuild3/conda-bld/pytorch_1524590658547/work/aten/src/TH/generic/THTensorMath.c:3586

このエラーがどこから発生するかと言うと、torch.stack() 関数から発生します。

このtorch.stack() が使われるタイミングは、先日の記事でも書いた通り、DataLoaderがデータセットからバッチ入力を取り出す際にバッチサイズ分のデータをスタックする時です。

ここでtorch.stack() のドキュメント を見ると、

Concatenates sequence of tensors along a new dimension.

All tensors need to be of the same size.

とあるように、torch.stack() に入力されるTensorは全て同一のshapeである必要があります。

エラーが起こりうるケース

このエラーが起こりうるケースとして、Object detection タスクが考えられます。

なぜかというと、当然ながらターゲットとなるオブジェクト数は画像内に写っているオブジェクトによって変動します。

例えばYOLOの場合、ターゲットとなるバウンディングボックスの情報は[xmin, ymin, xmax, ymax]であり、オブジェクト数がNとすると、(N, 4) のshapeのターゲットとなります。

対策

これの対策として、2つの方法が考えられます。

  • バッチサイズを1にする
  • collate_fnを自作する
バッチサイズを1に

下記に書かれているようにバッチサイズを1にすれば全て同一shapeになるため、上記のエラーは回避できます。

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0 · Issue #89 · marvis/pytorch-yolo2 · GitHub

ただこの方法は本質的な解決にならないのでおすすめできません。

collate_fnを自作する

DataLoaderクラスのインスタンス生成時の引数として、collate_fnというものがあります。

引数の説明を見ると

Parameters: collate_fn (callable, optional) – merges a list of samples to form a mini-batch.

この関数は以下のように、リスト化されたバッチデータを引数として呼びだされます。

samples = collate_fn([dataset[i] for i in batch_indices])

つまり、このcollate_fnを toch.stack() によるエラーが起きないように定義しておけば良いということになります。

例えば以下のようにすることが考えられます。

def my_collate_fn(batch):
    # datasetの出力が
    # [image, target] = dataset[batch_idx]
    # の場合.
    images = []
    targers = []
    for sample in batch:
        image, target = sample
        images.append(image)
        targets.append(targets)
    images = torch.stack(images, dim=0)
    return [images, targets]

まとめ

  • DataLoaderクラスはデフォルトではバッチデータ内のshapeを全て同じにする必要がある
  • collate_fnを自作することでバッチデータ作成時の挙動を制御できる