For an experiment with MetaFormer, I was trying to add the CIFAR100 dataset to the training script. Since CIFAR100 is quite small, I need it to repeat multiple times within one epoch. Therefore I added a new wrapper dataset:
from torch.utils.data import Dataset

class RepeatDataset(Dataset):
    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
        self.length = len(dataset) * repeats

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return self.length
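For context, here is roughly how I wire it up; the torchvision CIFAR100 root, the repeat count, and the create_loader arguments below are placeholders for illustration, not the exact values from my train.py:

from torchvision.datasets import CIFAR100
from timm.data import create_loader

# Wrap the small CIFAR100 train split so that one "epoch" walks over it several times.
base = CIFAR100(root='./data', train=True, download=True)   # placeholder root
train_dataset = RepeatDataset(base, repeats=10)              # hypothetical repeat count

loader = create_loader(
    train_dataset,
    input_size=(3, 224, 224),   # placeholder; the metaformer configs may use other sizes
    batch_size=128,
    is_training=True,
    use_prefetcher=False,
)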
But training failed with this error:
Traceback (most recent call last):
  File "/home/robin/code/metaformer/train.py", line 970, in <module>
    main()
  File "/home/robin/code/metaformer/train.py", line 732, in main
    train_metrics = train_one_epoch(
                    ^^^^^^^^^^^^^^^^
  File "/home/robin/code/metaformer/train.py", line 798, in train_one_epoch
    for batch_idx, (input, target) in enumerate(loader):
                                      ^^^^^^^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/timm/data/loader.py", line 131, in __iter__
    for next_input, next_target in self.loader:
                                   ^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 733, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1515, in _next_data
    return self._process_data(data, worker_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1550, in _process_data
    data.reraise()
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/_utils.py", line 750, in reraise
    raise exception
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/robin/miniconda3/envs/poolformer/lib/python3.12/site-packages/timm/data/mixup.py", line 305, in __call__
    output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
                                       ^^^^^^^^^^^^^^^^^
AttributeError: 'Image' object has no attribute 'shape'. Did you mean: 'save'?
It took me quite a long time to solve. The key is in the implementation of “timm.data.create_loader”: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/loader.py#L291. There, “create_loader” assigns a new value to “dataset.transform”, and in “timm.data.dataset” (https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/dataset.py#L66-L67) this newly set “transform” is checked and applied:
...
if self.transform is not None:
    img = self.transform(img)
...
Since RepeatDataset is a class I wrote myself, it does not forward the “dataset.transform = create_transform(...)” assignment to the wrapped dataset: the assignment only creates an attribute on the wrapper, the inner dataset’s transform is never set, and the loader ends up yielding raw PIL Images, which is what the mixup collate function fails on when it reads “batch[0][0].shape”.
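Here is a minimal reproduction of the mechanism, with toy classes standing in for timm’s dataset and for what create_loader does (these are simplified stand-ins, not the real timm code):

from torch.utils.data import Dataset

class ToyImageDataset(Dataset):
    """Stands in for timm's ImageDataset: applies self.transform if it is set."""
    def __init__(self):
        self.transform = None
    def __getitem__(self, idx):
        img = f"raw-image-{idx}"          # placeholder for a PIL Image
        if self.transform is not None:
            img = self.transform(img)
        return img, 0
    def __len__(self):
        return 4

class NaiveRepeatDataset(Dataset):
    """The original wrapper: does not forward .transform to the inner dataset."""
    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]
    def __len__(self):
        return len(self.dataset) * self.repeats

inner = ToyImageDataset()
wrapped = NaiveRepeatDataset(inner, repeats=2)

# This is effectively what create_loader does: it assigns the new transform
# to the dataset object it was handed -- here, the wrapper.
wrapped.transform = lambda img: f"transformed({img})"

print(inner.transform)   # None -> the inner dataset never sees the transform
print(wrapped[0])        # ('raw-image-0', 0) -> still the raw, untransformed item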
The fix came from ChatGPT, and I think it’s not bad:
from torch.utils.data import Dataset

class RepeatDataset(Dataset):
    def __init__(self, dataset, repeats):
        self.dataset = dataset
        self.repeats = repeats
        self.length = len(dataset) * repeats

    # Forward "transform" to the wrapped dataset, so that create_loader's
    # "dataset.transform = ..." assignment actually reaches it.
    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, value):
        self.dataset.transform = value

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]

    def __len__(self):
        return self.length
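A quick sanity check that the property forwarding does what create_loader expects, assuming torchvision’s CIFAR100 as the wrapped dataset (root and repeat count are placeholders):

from torchvision.datasets import CIFAR100

base = CIFAR100(root='./data', train=True, download=True)   # placeholder root
ds = RepeatDataset(base, repeats=10)

ds.transform = (lambda img: img)        # stand-in for create_loader's assignment
assert base.transform is ds.transform   # the transform now lives on the wrapped dataset
assert len(ds) == 10 * len(base)

The property keeps the wrapper transparent to create_loader without duplicating any of timm’s dataset code.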