Hanging of PyTorch’s data loader

source link: https://donghao.org/2023/05/05/hanging-of-pytorchs-data-loader/

Long story short: I am building a Siamese network for audio classification. With 50% probability, “dataset.py” picks a pair of audio clips from the same category but from different files; the other 50% of the time it picks clips from different categories. But when evaluation starts, the data loader hangs after fetching a few batches. Interrupting the run gives this traceback:

Traceback (most recent call last):                                                                                                                                                                                                        
  File "/home/robin/song/birdclef/old_train.py", line 395, in <module>                                                
    train(args, train_loader, eval_loader)                                                                                                                                                                                                  
  File "/home/robin/song/birdclef/old_train.py", line 280, in train                                                   
    accuracy = evaluate(args, net, eval_loader)                                                                                                                                                                                             
  File "/home/robin/song/birdclef/old_train.py", line 91, in evaluate                                                 
    sounds1, sounds2, type_ids = next(batch_iterator)                                                                 
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()                                                                                                                                                                                                                
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()                                                                                      
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1285, in _get_data                                                                                                              
    success, data = self._try_get_data()                                                                                                                                                                                                    
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)                                                                      
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/queue.py", line 180, in get                                   
    self.not_empty.wait(remaining)                                                                                    
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/threading.py", line 324, in wait                              
    gotit = waiter.acquire(True, timeout)                                                                                                                                                                                                   
KeyboardInterrupt 

As usual, I started by suspecting PyTorch. Is version 2.0 so new that it still has some flaws? I quickly rejected the idea: if PyTorch were the problem, why didn’t the same hang show up when I was not using the Siamese network?

Then I found this issue on PyTorch’s GitHub page. It pointed me to the clue: the new code in “dataset.py”. That is where I noticed the problem in my code:

            arr = self.cat_map[ebird_code]
            pair_wav_name = np.random.choice(arr)
            while pair_wav_name == wav_name:
                pair_wav_name = np.random.choice(arr)
            pair_sound = self.get_sound(pair_wav_name, ebird_code)

If a category has only one file, np.random.choice(arr) can only ever return wav_name, so this loop never exits. That is the cause of the hang.
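One way to make this kind of bug show up as an error instead of a silent hang is to cap the number of retries. The sketch below is only an illustration: the helper name pick_other and the max_tries parameter are hypothetical, and it uses the standard library's random module rather than np.random so that it runs on its own.

```python
import random


def pick_other(arr, wav_name, max_tries=1000, rng=random):
    """Rejection-sample a name from arr that differs from wav_name.

    Raises RuntimeError instead of looping forever when no other
    name turns up within max_tries attempts.
    """
    for _ in range(max_tries):
        candidate = rng.choice(arr)
        if candidate != wav_name:
            return candidate
    raise RuntimeError(
        f"no file other than {wav_name!r} found after {max_tries} tries"
    )
```

With a single-file category, pick_other(["a.wav"], "a.wav") raises after max_tries attempts. An exception raised inside a dataset is normally reported back by the data loader, so the failure would come with a traceback pointing at the dataset code rather than at a queue wait.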

The solution is simple:

            arr = self.cat_map[ebird_code]
            if len(arr) > 1:
                pair_wav_name = np.random.choice(arr)
                while pair_wav_name == wav_name:
                    pair_wav_name = np.random.choice(arr)
            else:
                pair_wav_name = wav_name
            pair_sound = self.get_sound(pair_wav_name, ebird_code)
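The len(arr) > 1 guard still relies on the loop eventually drawing a different name, which would not happen if arr held the same file name more than once and nothing else. A loop-free variant is to remove the current file from the candidates first and draw once. This is a sketch under assumptions, not the code from “dataset.py”: pick_pair is a hypothetical helper, it uses the standard library's random module instead of np.random, and it keeps the same fallback of pairing a file with itself.

```python
import random


def pick_pair(arr, wav_name, rng=random):
    """Pick a file name from arr other than wav_name.

    Falls back to wav_name itself when the category has no other file,
    mirroring the else-branch of the fix above.
    """
    # Filter first, so a single draw is always enough and no loop is needed.
    candidates = [name for name in arr if name != wav_name]
    if not candidates:
        return wav_name
    return rng.choice(candidates)
```

The trade-off is that this builds a new list on every call, which costs time proportional to the size of the category. If that matters, the filtered lists could be precomputed.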

May 5, 2023 - 1:25 RobinDong machine learning
PyTorch