Epochs shortened after resuming mid-epoch with Iterable dataset+StatefulDataloader(persistent_workers=True) #7447
Comments
Thanks for reporting! Maybe we should store the epoch in the state_dict, and then when the dataset is iterated on again after setting a new epoch it should restart from scratch instead of resuming? wdyt?
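Conceptually, the proposed check amounts to something like the sketch below: only honor the saved state if it was captured for the same epoch, otherwise start the new epoch from scratch. This is an illustration of the idea only, not the actual `datasets` internals; the function name and state keys are made up for the example.

```python
# Sketch of the proposed behavior only -- not the real `datasets` code.
# `examples` stands in for the contents of one epoch.
def iterate_epoch(examples, epoch, state_dict=None):
    # Resume only if the saved state was captured for this same epoch;
    # otherwise ignore it and start the new epoch from scratch.
    if state_dict is not None and state_dict.get("epoch") == epoch:
        start = state_dict["num_examples_yielded"]
    else:
        start = 0
    for example in examples[start:]:
        yield example

# A state captured mid-way through epoch 0 resumes epoch 0 ...
assert list(iterate_epoch(list(range(5)), 0, {"epoch": 0, "num_examples_yielded": 3})) == [3, 4]
# ... while a later epoch starts over instead of replaying only the leftover examples.
assert list(iterate_epoch(list(range(5)), 1, {"epoch": 0, "num_examples_yielded": 3})) == [0, 1, 2, 3, 4]
```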
But why does this only happen when `persistent_workers=True`?
I think `persistent_workers=False` simply ignores the dataset state_dict when it starts a new epoch, that's why the issue doesn't appear in that case.
I opened #7451 to fix the issue, let me know if it works for you.
I just released the fix.

PS: in your script you probably want to set the epoch like this, otherwise it's still set to 0 after the first epoch:

```diff
  if state_dict is None:
-     ds.set_epoch(epoch)
      epoch += 1
+     ds.set_epoch(epoch)
```
Describe the bug
When resuming mid-epoch with a `datasets` `IterableDataset` and `torchdata.stateful_dataloader.StatefulDataLoader(persistent_workers=True)`, the epochs after resuming only iterate through the examples that were left in the epoch when the training was interrupted. For example, in the script below training is interrupted on step 124 (epoch 1) when 3 batches are left. Then after resuming, the remaining epochs (2 and 3) only iterate through these 3 batches.

Steps to reproduce the bug

Run the following script with and without PERSISTENT_WORKERS=true. The script checkpoints when there are three batches left in the second epoch. After resuming, only those last three batches are repeated in the rest of the epochs.
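A minimal sketch of this kind of script is below; the dataset size, batch size, worker count, and checkpoint step are illustrative rather than the exact values from the original report, and the interruption/resume is simulated in a single process instead of restarting the script.

```python
import os

from datasets import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

PERSISTENT_WORKERS = os.environ.get("PERSISTENT_WORKERS", "false").lower() == "true"
NUM_EPOCHS = 4
CHECKPOINT_STEP = 124  # interrupt a few batches before the end of epoch 1

# Small iterable dataset built with the `datasets` library (256 rows -> 64 batches per epoch).
ds = Dataset.from_dict({"x": list(range(256))}).to_iterable_dataset(num_shards=4)

dataloader = StatefulDataLoader(
    ds,
    batch_size=4,
    num_workers=2,
    persistent_workers=PERSISTENT_WORKERS,
)

state_dict = None  # simulated checkpoint, kept in memory instead of on disk
epoch = 0
global_step = 0
while epoch < NUM_EPOCHS:
    if state_dict is not None:
        # "Resume": restore the dataloader mid-epoch from the saved state.
        dataloader.load_state_dict(state_dict)
        state_dict = None
    else:
        ds.set_epoch(epoch)

    interrupted = False
    num_batches = 0
    for batch in dataloader:
        num_batches += 1
        global_step += 1
        if global_step == CHECKPOINT_STEP:
            # Simulate an interruption: save the dataloader state and stop this epoch.
            state_dict = dataloader.state_dict()
            interrupted = True
            break

    print(f"epoch {epoch}: {num_batches} batches")
    if not interrupted:
        epoch += 1
```

With PERSISTENT_WORKERS=true, the reported bug shows up as the epochs after the resumed one containing only the leftover batches; with PERSISTENT_WORKERS=false every epoch after resuming is full-length.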
If it helps, the following are the two state_dicts for the dataloader saved at the same step with the two settings; the left one is for PERSISTENT_WORKERS=False.
Expected behavior
All the elements in the dataset should be iterated through in the epochs following the one where we resumed. The expected behavior can be seen by setting PERSISTENT_WORKERS=False.

Environment info
torch==2.5.1
datasets==3.3.2
torchdata>=0.9.0