【PyTorch】 DataLoaderのバッチ取り出し時の挙動について
こんにちは。
今回はPyTorchのDataLoaderがバッチデータを取り出す際の挙動について触れようと思います。
環境
- PyTorch: 0.4.0
DataLoaderが対応する型
DataLoaderはDatasetクラスがサンプルしたデータをバッチサイズ分スタックして出力するというのが基本の動作になります。
しかし、当然ながらデータの種類、各人の実装の方法によってDatasetクラスがサンプルするデータの型もバラバラになります。
そこでDataLoaderクラスではその型のばらつきを吸収するように実装されています。
default_collate()
DataLoaderクラスが実装されているモジュールは、torch.utils.data.dataloader.py になります。
その中を覗くと、"default_collate()" 関数があります。 この関数が型チェックを行い、その型に応じた方法でバッチサイズ分データをスタックするという処理を担っています。
この関数を少しずつ読み解いていきましょう。
def default_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
error_msgにあるようにデフォルトでは、tensor, numpy, number(int, float), dict, list に対応しています。
以下はif文で型に応じた処理を行っています。
まずは、torch.Tensorの場合、
if isinstance(batch[0], torch.Tensor): out = None if _use_shared_memory: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) # the total number of elements in the tensor storage = batch[0].storage()._new_shared(numel) out = batch[0].new(storage) # return torch.stack(batch, 0, out=out) # outにスタック結果を格納
if _use_shared_memory: 以下は無視すると、torch.stack() によってスタックされています。
続いて、numpy.ndarrayの場合
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': elem = batch[0] if elem_type.__name__ == 'ndarray': # array of string classes and object if re.search('[SaUO]', elem.dtype.str) is not None: raise TypeError(error_msg.format(elem.dtype)) return torch.stack([torch.from_numpy(b) for b in batch], 0) if elem.shape == (): # scalars py_type = float if elem.dtype.name.startswith('float') else int return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
後半のスカラーを用いる場合(np.asscalar)がどんな状況かはよくわかりませんが、こちらも基本的にはtorch.stack() によってスタックしています。
numbers, dict, listなど残りはまとめて
# int elif isinstance(batch[0], int_classes): return torch.LongTensor(batch) # float elif isinstance(batch[0], float): return torch.DoubleTensor(batch) # str elif isinstance(batch[0], string_classes): return batch # dict elif isinstance(batch[0], collections.Mapping): return {key: default_collate([d[key] for d in batch]) for key in batch[0]} # list elif isinstance(batch[0], collections.Sequence): transposed = zip(*batch) return [default_collate(samples) for samples in transposed]
ここで注意が必要なのはdictの挙動です。
Datasetによってdictをサンプルするようにしていた場合は、DataLoaderが出力する結果は同じkeyにスタックされて出力されます。
まとめ
以上をまとめると、
- DataLoaderは入力の型に応じた内部処理によって汎用性の高いバッチ取り出しをしている
- スタック方法は、torch.stack()
- dict入力の場合、同じkeyにスタックした結果を返す