
Epochs shortened after resuming mid-epoch with Iterable dataset+StatefulDataloader(persistent_workers=True) #7447

Closed
dhruvdcoder opened this issue Mar 12, 2025 · 5 comments · Fixed by #7451

Comments

dhruvdcoder commented Mar 12, 2025

Describe the bug

When torchdata.stateful_dataloader.StatefulDataLoader(persistent_workers=True) is used, the epochs after resuming only iterate through the examples that were left in the epoch when training was interrupted. For example, in the script below, training is interrupted at step 124 (epoch 1) with 3 batches left. After resuming, the remaining epochs (2 and 3) only iterate through those 3 batches.

Steps to reproduce the bug

Run the following script with and without PERSISTENT_WORKERS=true.

#!/usr/bin/env python3
# torch==2.5.1
# datasets==3.3.2
# torchdata>=0.9.0
import datasets
import pprint
from torchdata.stateful_dataloader import StatefulDataLoader

import os

PERSISTENT_WORKERS = (
    os.environ.get("PERSISTENT_WORKERS", "False").lower() == "true"
)

# PERSISTENT_WORKERS = True  # Incorrect resume


# ds = datasets.load_from_disk("dataset").to_iterable_dataset(num_shards=4)
def generator():
    for i in range(128):
        yield {"x": i}


ds = datasets.Dataset.from_generator(
    generator, features=datasets.Features({"x": datasets.Value("int32")})
).to_iterable_dataset(num_shards=4)

dl = StatefulDataLoader(
    ds, batch_size=2, num_workers=2, persistent_workers=PERSISTENT_WORKERS
)
global_step = 0
epoch = 0
ds_state_dict = None
state_dict = None
resumed = False
while True:
    if epoch >= 3:
        break
    if state_dict is not None:
        dl.load_state_dict(state_dict)
        state_dict = None
        ds_state_dict = None
        resumed = True
        print("resumed")
    for i, batch in enumerate(dl):
        print(f"epoch: {epoch}, global_step: {global_step}, batch: {batch}")
        global_step += 1  # consume datapoint
        # simulate error
        if global_step == 124 and not resumed:
            ds_state_dict = ds.state_dict()
            state_dict = dl.state_dict()
            print("checkpoint")
            print("ds_state_dict")
            pprint.pprint(ds_state_dict)
            print("dl_state_dict")
            pprint.pprint(state_dict)
            break

    if state_dict is None:
        ds.set_epoch(epoch)
        epoch += 1

The script checkpoints when there are three batches left in the second epoch. After resuming, only those last three batches are iterated in the remaining epochs.

If it helps, below are the two dataloader state_dicts saved at the same step with the two settings; the left one is for PERSISTENT_WORKERS=False.

[Image: side-by-side dump of the two dataloader state_dicts]

Expected behavior

All the elements in the dataset should be iterated through in the epochs following the one in which training resumed. The expected behavior can be seen by setting PERSISTENT_WORKERS=False.

Environment info

torch==2.5.1
datasets==3.3.2
torchdata>=0.9.0

lhoestq commented Mar 13, 2025

Thanks for reporting! Maybe we should store the epoch in the state_dict, and then when the dataset is iterated on again after setting a new epoch it should restart from scratch instead of resuming? wdyt?
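
For illustration, here is a minimal sketch of that idea (purely hypothetical, not the actual datasets internals): the state_dict records which epoch it was captured in, and a restored state is only applied when the dataset is still on that epoch.

class EpochAwareIterable:
    """Toy iterable whose saved state is tied to the epoch it was taken in."""

    def __init__(self, data):
        self.data = list(data)
        self.epoch = 0
        self._yielded = 0      # examples yielded in the current epoch
        self._restored = None  # state waiting to be applied on the next __iter__

    def set_epoch(self, epoch):
        self.epoch = epoch

    def state_dict(self):
        return {"epoch": self.epoch, "num_yielded": self._yielded}

    def load_state_dict(self, state):
        self._restored = state

    def __iter__(self):
        state, self._restored = self._restored, None
        # Resume mid-epoch only if the epoch matches; otherwise start from scratch.
        skip = state["num_yielded"] if state and state["epoch"] == self.epoch else 0
        self._yielded = skip
        for example in self.data[skip:]:
            self._yielded += 1
            yield example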

@dhruvdcoder (Author)

But why does this only happen when persistent_workers=True? I would expect it to work correctly even without storing the epoch number in the state_dict of the iterable dataset.

lhoestq commented Mar 14, 2025

I think persistent_workers=False simply ignores the dataset state_dict when it starts a new epoch, which is why the issue doesn't appear in that case.

lhoestq commented Mar 14, 2025

I opened #7451 to fix the issue. Let me know if it works for you.

lhoestq commented Mar 14, 2025

I just released datasets 3.4, which includes the fix :)

PS: in your script you probably want to set the epoch like this, otherwise it's still set to 0 after the first epoch:

    if state_dict is None:
-       ds.set_epoch(epoch)
        epoch += 1
+       ds.set_epoch(epoch)
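
In applied form (same logic as the diff above), the end-of-epoch block would read:

    if state_dict is None:
        epoch += 1
        ds.set_epoch(epoch)  # increment first so the new epoch number is used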
