Long story short: I am trying to build a Siamese network for audio classification. With 50% probability, “dataset.py” tries to find a pair of audio files from the same category but different files (and with the other 50% probability, a pair from different categories). But once evaluation starts, it hangs after fetching a few batches. Here is the traceback:

Traceback (most recent call last):
  File "/home/robin/song/birdclef/old_train.py", line 395, in <module>
    train(args, train_loader, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 280, in train
    accuracy = evaluate(args, net, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 91, in evaluate
    sounds1, sounds2, type_ids = next(batch_iterator)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1285, in _get_data
    success, data = self._try_get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/queue.py", line 180, in get
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/threading.py", line 324, in wait
    gotit = waiter.acquire(True, timeout)

As usual, I started by suspecting PyTorch. Was version 2.0 too new and still carrying some bugs? Then I quickly rejected the idea: if PyTorch were the problem, why didn't the same hang happen when I wasn't using a Siamese network?

Then I found this issue on the PyTorch GitHub page. It pointed me to the real culprit: the new code in “dataset.py”. Now I noticed the problem in my code:

            arr = self.cat_map[ebird_code]
            pair_wav_name = np.random.choice(arr)
            while pair_wav_name == wav_name:
                pair_wav_name = np.random.choice(arr)
            pair_sound = self.get_sound(pair_wav_name, ebird_code)

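If a category has only one file, np.random.choice(arr) always returns that same file, so this loop runs forever. The DataLoader worker stuck in the loop never delivers its sample, and the main process blocks indefinitely waiting on the data queue, which is exactly where the traceback above ends. This is the cause of the hang.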
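To see why the loop can never exit, here is a minimal sketch of the same rejection sampling with a cap on the number of attempts. The helper name `sample_other` and the `max_tries` parameter are hypothetical, not part of the original dataset.py, and it uses NumPy's `default_rng` Generator instead of the legacy `np.random.choice`, which behaves the same way for this purpose.

```python
import numpy as np

def sample_other(arr, wav_name, max_tries=1000, rng=None):
    # Hypothetical helper: the same rejection loop as in dataset.py,
    # but capped at max_tries so it cannot spin forever.
    if rng is None:
        rng = np.random.default_rng()
    for _ in range(max_tries):
        candidate = rng.choice(arr)
        if candidate != wav_name:
            return candidate
    # No different file was found; with a one-file category this is guaranteed.
    return None

# With a single file, every draw returns wav_name, so the uncapped loop never ends.
print(sample_other(["only.ogg"], "only.ogg"))
```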

The solution is simple:

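            arr = self.cat_map[ebird_code]
            if len(arr) > 1:
                # Keep drawing until we get a different file from the same category
                pair_wav_name = np.random.choice(arr)
                while pair_wav_name == wav_name:
                    pair_wav_name = np.random.choice(arr)
            else:
                # Only one file in this category: pair it with itself
                pair_wav_name = wav_name
            pair_sound = self.get_sound(pair_wav_name, ebird_code)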
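For reference, here is a minimal sketch of the whole 50/50 pairing step written as a standalone function. The name `sample_pair`, its return tuple, and the 1/0 "same category" label are assumptions for illustration; the original dataset returns `type_ids`, whose exact meaning isn't shown here. Instead of a retry loop, it removes the current file from the candidates first, so there is nothing that can loop forever.

```python
import numpy as np

def sample_pair(cat_map, ebird_code, wav_name, rng=None):
    # Hypothetical sketch. cat_map mirrors self.cat_map: category code -> list of files.
    # Returns (pair_code, pair_wav_name, same), with same=1 for a same-category pair.
    # Assumes cat_map holds at least two categories.
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() < 0.5:
        # Same category: choose among the *other* files, falling back to wav_name itself
        others = [n for n in cat_map[ebird_code] if n != wav_name]
        pair_wav_name = others[rng.integers(len(others))] if others else wav_name
        return ebird_code, pair_wav_name, 1
    # Different category: pick another category, then any file in it
    other_codes = [c for c in cat_map if c != ebird_code]
    pair_code = other_codes[rng.integers(len(other_codes))]
    files = cat_map[pair_code]
    return pair_code, files[rng.integers(len(files))], 0

cat_map = {"a": ["a1.ogg"], "b": ["b1.ogg", "b2.ogg"]}
print(sample_pair(cat_map, "a", "a1.ogg"))
```

Filtering first costs one pass over the category's file list, which is cheap compared with loading audio, and it makes the one-file case explicit rather than relying on the loop condition.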