0.5から始める機械学習

Machine Learning, Deep Learning, Computer Vision に関する備忘録

【PyTorch】 DataLoaderのバッチ取り出し時の挙動について

こんにちは。

今回はPyTorchのDataLoaderがバッチデータを取り出す際の挙動について触れようと思います。

環境

  • PyTorch: 0.4.0

DataLoaderが対応する型

DataLoaderはDatasetクラスがサンプルしたデータをバッチサイズ分スタックして出力するというのが基本の動作になります。

しかし、当然ながらデータの種類、各人の実装の方法によってDatasetクラスがサンプルするデータの型もバラバラになります。

そこでDataLoaderクラスではその型のばらつきを吸収するように実装されています。

default_collate()

DataLoaderクラスが実装されているモジュールは、torch.utils.data.dataloader.py になります。

その中を覗くと、"default_collate()" 関数があります。 この関数が型チェックを行い、その型に応じた方法でバッチサイズ分データをスタックするという処理を担っています。

この関数を少しずつ読み解いていきましょう。

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"

error_msgにあるようにデフォルトでは、tensor, numpy, number(int, float), dict, list に対応しています。

以下はif文で型に応じた処理を行っています。

まずは、torch.Tensorの場合、

if isinstance(batch[0], torch.Tensor):
   out = None
    if _use_shared_memory:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum([x.numel() for x in batch])  # the total number of elements in the tensor
        storage = batch[0].storage()._new_shared(numel)
        out = batch[0].new(storage)  # 
    return torch.stack(batch, 0, out=out)  # outにスタック結果を格納

if _use_shared_memory: 以下は無視すると、torch.stack() によってスタックされています。

続いて、numpy.ndarrayの場合

elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
        and elem_type.__name__ != 'string_':
    elem = batch[0]
    if elem_type.__name__ == 'ndarray':
        # array of string classes and object
        if re.search('[SaUO]', elem.dtype.str) is not None:
            raise TypeError(error_msg.format(elem.dtype))

        return torch.stack([torch.from_numpy(b) for b in batch], 0)
    if elem.shape == ():  # scalars
        py_type = float if elem.dtype.name.startswith('float') else int
        return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))

後半のスカラーを用いる場合(np.asscalar)がどんな状況かはよくわかりませんが、こちらも基本的にはtorch.stack() によってスタックしています。

numbers, dict, listなど残りはまとめて

# int
elif isinstance(batch[0], int_classes):
    return torch.LongTensor(batch)
# float
elif isinstance(batch[0], float):
    return torch.DoubleTensor(batch)
# str
elif isinstance(batch[0], string_classes):
    return batch
# dict
elif isinstance(batch[0], collections.Mapping):
    return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
# list
elif isinstance(batch[0], collections.Sequence):
    transposed = zip(*batch)
    return [default_collate(samples) for samples in transposed]

ここで注意が必要なのはdictの挙動です。

Datasetによってdictをサンプルするようにしていた場合は、DataLoaderが出力する結果は同じkeyにスタックされて出力されます。

まとめ

以上をまとめると、

  • DataLoaderは入力の型に応じた内部処理によって汎用性の高いバッチ取り出しをしている
    • スタック方法は、torch.stack()
  • dict入力の場合、同じkeyにスタックした結果を返す