Hanging of PyTorch’s data loader
source link: https://donghao.org/2023/05/05/hanging-of-pytorchs-data-loader/
Long story short: I am trying to build a Siamese network for audio classification. With 50% probability, “dataset.py” picks a pair of audio clips from the same category but from different files; the other 50% of the time, it picks a pair from different categories. But when evaluation starts, the data loader hangs after fetching a few batches. Interrupting the process gives this trace:
Traceback (most recent call last):
  File "/home/robin/song/birdclef/old_train.py", line 395, in <module>
    train(args, train_loader, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 280, in train
    accuracy = evaluate(args, net, eval_loader)
  File "/home/robin/song/birdclef/old_train.py", line 91, in evaluate
    sounds1, sounds2, type_ids = next(batch_iterator)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1285, in _get_data
    success, data = self._try_get_data()
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/queue.py", line 180, in get
    self.not_empty.wait(remaining)
  File "/home/robin/miniconda3/envs/bird/lib/python3.10/threading.py", line 324, in wait
    gotit = waiter.acquire(True, timeout)
KeyboardInterrupt
As usual, I started by suspecting PyTorch. Is this version of PyTorch so new (2.0) that it contains some flaws? I quickly rejected that idea: if PyTorch were the problem, why didn't the same hang occur when I wasn't using the Siamese network?
Then I found this issue on PyTorch's GitHub page. It pointed me to the clue: the new code in “dataset.py”. That is where I noticed the problem in my code:
arr = self.cat_map[ebird_code]
pair_wav_name = np.random.choice(arr)
while pair_wav_name == wav_name:
    pair_wav_name = np.random.choice(arr)
pair_sound = self.get_sound(pair_wav_name, ebird_code)
If a category has only one file, np.random.choice(arr) can only ever return wav_name, so this loop never ends. That is the reason for the hang.
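To illustrate why the trace ends in a queue wait rather than inside “dataset.py”, here is a minimal sketch using only the standard library. It is not PyTorch's actual worker code: a plain thread stands in for a data-loading worker, random.choice stands in for np.random.choice, and the names stuck_worker, consumer_times_out and the stop event are hypothetical and exist only so the demo can terminate.

```python
import queue
import random
import threading


def stuck_worker(out_queue, stop):
    """Stand-in for a worker that hits a category with a single file."""
    arr = ["only.wav"]      # hypothetical one-file category
    wav_name = "only.wav"   # the file the sample was built from
    pair_wav_name = random.choice(arr)
    # Same rejection loop as in the post. With one file it can never exit
    # on its own; `stop` is added only so this demo can end.
    while pair_wav_name == wav_name and not stop.is_set():
        pair_wav_name = random.choice(arr)
    if not stop.is_set():
        out_queue.put(pair_wav_name)


def consumer_times_out(timeout=0.2):
    """Return True if the consumer gives up waiting for the stuck worker."""
    out_queue = queue.Queue()
    stop = threading.Event()
    worker = threading.Thread(
        target=stuck_worker, args=(out_queue, stop), daemon=True
    )
    worker.start()
    try:
        # The consumer blocks here, just like the last frames of the trace.
        out_queue.get(timeout=timeout)
        timed_out = False
    except queue.Empty:
        timed_out = True
    finally:
        stop.set()
        worker.join()
    return timed_out


if __name__ == "__main__":
    print(consumer_times_out())
```

The consumer never sees an error from the worker; it just waits for data that never arrives, which is why the hang looks like a data loader problem at first.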
The solution is simple:
arr = self.cat_map[ebird_code]
if len(arr) > 1:
    pair_wav_name = np.random.choice(arr)
    while pair_wav_name == wav_name:
        pair_wav_name = np.random.choice(arr)
else:
    # Only one file in this category: fall back to pairing the file with itself.
    pair_wav_name = wav_name
pair_sound = self.get_sound(pair_wav_name, ebird_code)
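As a possible variation on the fix above, the retry loop can be avoided entirely by choosing only among the other files. This is a sketch under assumptions, not the code from “dataset.py”: it assumes cat_map maps a category code to a list of file names, uses the standard library's random.choice in place of np.random.choice, and the helper names pick_pair and single_file_categories are hypothetical.

```python
import random


def pick_pair(arr, wav_name):
    """Pick a file from `arr` other than `wav_name`, without a retry loop."""
    candidates = [name for name in arr if name != wav_name]
    if not candidates:
        # No other file in this category: pair the file with itself,
        # which matches the fallback in the fix above.
        return wav_name
    return random.choice(candidates)


def single_file_categories(cat_map):
    """List category codes with fewer than two files (assumed dict of lists)."""
    return sorted(code for code, files in cat_map.items() if len(files) < 2)


if __name__ == "__main__":
    # Hypothetical category map for illustration only.
    cat_map = {"aaa": ["a1.wav"], "bbb": ["b1.wav", "b2.wav"]}
    print(single_file_categories(cat_map))
    print(pick_pair(cat_map["bbb"], "b1.wav"))
```

Running single_file_categories once before training reports which categories will fall back to pairing a file with itself, so they can be reviewed or excluded.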
May 5, 2023 - 1:25
RobinDong
machine learning
PyTorch