Long story short: I am trying to build a Siamese network for audio classification. With 50% probability, “dataset.py” looks for a pair of audio files in the same category but from different files; with the other 50% probability, it picks a pair from different categories. But once evaluation starts, the program hangs after fetching a few batches. Here is the traceback after I interrupted it:
Traceback (most recent call last):
  File "/home/robin/song/birdclef/old_train.py", line 395, in <module>
    train(args, train_loader, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 280, in train
    accuracy = evaluate(args, net, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 91, in evaluate
    sounds1, sounds2, type_ids = next(batch_iterator)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1285, in _get_data
    success, data = self._try_get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/threading.py", line 324, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt
As usual, I started by suspecting PyTorch. Is PyTorch 2.0 so new that it still has some flaws? Then I quickly rejected the idea: if PyTorch were the problem, why didn't the same hang happen when I wasn't using the Siamese network?
Then I found this issue on the PyTorch GitHub page. It pointed me to the real clue: the new code in “dataset.py”. Now I noticed the problem in my code:
arr = self.cat_map[ebird_code]
pair_wav_name = np.random.choice(arr)
# Redraw until we get a file other than the anchor wav_name
while pair_wav_name == wav_name:
    pair_wav_name = np.random.choice(arr)
pair_sound = self.get_sound(pair_wav_name, ebird_code)
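If a category has only one file, np.random.choice(arr) always returns wav_name itself, so this loop never ends. The DataLoader worker running it never delivers its batch, and the main process keeps waiting on the data queue, which is exactly where the traceback above was stuck. This is what caused the hang.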
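A quick way to see in advance which categories would trigger this is to scan the category map before training. The sketch below assumes, as in the snippet above, that cat_map maps a category code to a list of file names; the function name and the sample map are hypothetical. It counts distinct names, so a list that repeats the same file is flagged too.

```python
def categories_with_single_file(cat_map):
    """Return the sorted category codes that have fewer than two distinct files.

    For any anchor file in such a category, a loop that keeps redrawing
    until it gets a *different* file name can never exit.
    """
    return sorted(code for code, files in cat_map.items() if len(set(files)) < 2)


# Hypothetical cat_map for illustration: category code -> audio file names.
cat_map = {
    "amerob": ["a1.ogg", "a2.ogg", "a3.ogg"],
    "norcar": ["n1.ogg"],
    "blujay": ["b1.ogg", "b2.ogg"],
}
print(categories_with_single_file(cat_map))  # ['norcar']
```

Running this on the evaluation split's map shows whether any category there has a lone file.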
The solution is simple:
arr = self.cat_map[ebird_code]
if len(arr) > 1:
    pair_wav_name = np.random.choice(arr)
    while pair_wav_name == wav_name:
        pair_wav_name = np.random.choice(arr)
else:
    # Only one file in this category: pair the anchor with itself
    pair_wav_name = wav_name
pair_sound = self.get_sound(pair_wav_name, ebird_code)
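Another option is to drop the retry loop altogether: filter out the anchor first, then draw once from what remains. This is a minimal sketch, not the code from “dataset.py”; the helper name pick_pair_name is hypothetical, and it uses Python's random module in place of np.random.choice.

```python
import random


def pick_pair_name(arr, wav_name, rng=random):
    """Pick a file name from `arr` other than `wav_name`, with no retry loop.

    Falls back to `wav_name` itself when the category has no other file,
    like the `else` branch of the fix above.
    """
    # Remove every occurrence of the anchor, so one draw is always enough.
    candidates = [name for name in arr if name != wav_name]
    if not candidates:
        return wav_name
    return rng.choice(candidates)


rng = random.Random(0)
print(pick_pair_name(["n1.ogg"], "n1.ogg", rng))  # n1.ogg
print(pick_pair_name(["a1.ogg", "a2.ogg", "a3.ogg"], "a1.ogg", rng))  # a2.ogg or a3.ogg
```

The filter costs one pass over the category's file list per sample, while the rejection loop usually finishes in a draw or two. In exchange, this version always terminates, even when the list holds the same name several times, a case where `len(arr) > 1` is true but the loop would still spin forever.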