Commit 61ec723

Committed Feb 19, 2025
Using system generated seed in RandomSampler (#1441)
* add new sampler tests
* update seed generation in sampler
* run precommit
* update seed generation
* change variable name
* update comment
* add seed to tests
* run precommit
1 parent 0cd8234 commit 61ec723
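In short: `RandomSampler` previously fell back to a generator seeded with the constant 1, so every run shuffled identically; this commit draws that seed from torch's global RNG instead. A minimal sketch (plain PyTorch; nothing below is from the commit itself) of why shuffling remains reproducible when the global seed is fixed:

```python
import torch

def system_generated_seed() -> int:
    # Mirrors the expression the sampler now uses: the draw comes from
    # torch's default generator, so it honors torch.manual_seed.
    return int(torch.empty((), dtype=torch.int64).random_().item())

torch.manual_seed(0)
seed_a = system_generated_seed()
torch.manual_seed(0)
seed_b = system_generated_seed()

# Same global seed -> same sampler seed -> same shuffle order.
assert seed_a == seed_b
```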

File tree

3 files changed: +63 −10

test/stateful_dataloader/test_sampler.py (+56 −8)
```diff
@@ -14,7 +14,7 @@
 from torch.utils.data import Dataset
 
 from torchdata.stateful_dataloader import StatefulDataLoader
-from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
+from torchdata.stateful_dataloader.sampler import RandomSampler, StatefulDistributedSampler
 
 
 class MockDataset(Dataset):
@@ -34,7 +34,10 @@ def __getitem__(self, idx):
     "Fails with TSAN with the following error: starting new threads after multi-threaded "
     "fork is not supported. Dying (set die_after_fork=0 to override)",
 )
-@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
+@unittest.skipIf(
+    TEST_WITH_ASAN,
+    "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223",
+)
 class TestDataLoader(TestCase):
     def setUp(self):
         super().setUp()
@@ -44,7 +47,12 @@ def setUp(self):
     def test_initialization_StatefulDistributedSampler(self):
 
         sampler = StatefulDistributedSampler(
-            self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
+            self.dataset,
+            num_replicas=10,
+            rank=0,
+            shuffle=False,
+            seed=42,
+            drop_last=False,
         )
         self.assertEqual(sampler.dataset, self.dataset)
         self.assertEqual(sampler.num_replicas, 10)
@@ -139,7 +147,8 @@ def test_drop_last_effect(self):
         )
 
         self.assertTrue(
-            len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
+            len(indices_with_drop) <= len(indices_without_drop),
+            "Drop last should result in fewer or equal indices",
         )
 
     def test_data_order_with_shuffle(self):
@@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self):
         for batch in dataloader:
             data_loaded.extend(batch)
         self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )
 
     def test_data_order_without_shuffle(self):
         sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
@@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self):
         for batch in dataloader:
             data_loaded.extend(batch)
         self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
-        self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
-        self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")
+        self.assertEqual(
+            data_loaded,
+            data_sampled,
+            "Data loaded by DataLoader should match data sampled by sampler",
+        )
+        self.assertEqual(
+            data_loaded,
+            list(range(100)),
+            "Data loaded by DataLoader should be in original order",
+        )
 
     def test_data_distribution_across_replicas(self):
         num_replicas = 5
@@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self):
             data_loaded.extend([int(x.item()) for x in batch])
         all_data.extend(data_loaded)
         self.assertEqual(
-            sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
+            sorted(all_data),
+            list(range(100)),
+            "All data points should be covered exactly once across all replicas",
         )
 
+    def test_seed_replicability(self):
+        # Test that the same seed will result in the same data order
+        # We first pick a random number as seed, then use it to initialize two dataloaders
+        min_seed, max_seed = 0, 1000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+
+        dataloader1 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results1 = list(dataloader1)
+
+        # Repeat the same process with the same seed
+        torch.manual_seed(seed)
+        dataloader2 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results2 = list(dataloader2)
+
+        # Repeat the same process with a different seed, making sure that the seed is different
+        min_seed, max_seed = 1000, 2000  # [min_seed, max_seed)
+        seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item()
+        torch.manual_seed(seed)
+        dataloader3 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True)
+        results3 = list(dataloader3)
+
+        self.assertEqual(results1, results2, "Data should be replicable with same seed")
+        self.assertNotEqual(results1, results3, "Data should not be replicable with different seed")
+
 
 if __name__ == "__main__":
     run_tests()
```
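A usage note, as a sketch rather than part of the commit (`RangeDataset` is a hypothetical toy dataset): `StatefulDataLoader` snapshots the sampler's generator state, so a run can still be checkpointed mid-epoch and resumed with the same shuffle order even though the seed is now system-generated:

```python
import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class RangeDataset(Dataset):
    # Hypothetical toy dataset yielding 0..9.
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx

dl = StatefulDataLoader(RangeDataset(), batch_size=1, shuffle=True)
it = iter(dl)
head = [next(it) for _ in range(3)]
snapshot = dl.state_dict()  # checkpoint mid-epoch

dl2 = StatefulDataLoader(RangeDataset(), batch_size=1, shuffle=True)
dl2.load_state_dict(snapshot)  # resume from the checkpoint
tail = list(dl2)  # continues the same permutation where `dl` left off
```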

test/stateful_dataloader/test_state_dict.py (+4)
```diff
@@ -1458,6 +1458,8 @@ def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
         )
 
     def _run(self, data_size, num_workers, batch_size, shuffle):
+        # For reproducibility of testing, fix the seed
+        torch.manual_seed(0)
         dataloader1 = self.get_map_dl(
             data_size=data_size,
             num_workers=num_workers,
@@ -1493,6 +1495,8 @@ def _run(self, data_size, num_workers, batch_size, shuffle):
         self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
 
         # now run a second dataloader for 4 epochs and check if the order is the same.
+        # we need to fix the seed again to bring the initial conditions back to the same state as when the first dataloader was instantiated
+        torch.manual_seed(0)
         dataloader2 = self.get_map_dl(
             data_size=data_size,
             num_workers=num_workers,
```
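These two `torch.manual_seed(0)` calls are needed precisely because of the sampler change: with the old constant seed, two independently constructed dataloaders always shuffled identically, whereas each one now draws a fresh seed from the global RNG. A hypothetical illustration (not from the commit) of the failure mode this avoids:

```python
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

data = list(range(100))  # a plain list works as a map-style dataset

# Without pinning the global seed, the two loaders draw different
# sampler seeds, so their shuffle orders generally differ:
dl_a = StatefulDataLoader(data, batch_size=1, shuffle=True)
dl_b = StatefulDataLoader(data, batch_size=1, shuffle=True)
# list(dl_a) == list(dl_b) is no longer guaranteed.

# Re-seeding before each construction restores matching orders:
torch.manual_seed(0)
dl_c = StatefulDataLoader(data, batch_size=1, shuffle=True)
torch.manual_seed(0)
dl_d = StatefulDataLoader(data, batch_size=1, shuffle=True)
# Now dl_c and dl_d yield the same sequence of batches.
```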

torchdata/stateful_dataloader/sampler.py (+3 −2)
```diff
@@ -88,9 +88,10 @@ def __init__(
         self.replacement = replacement
         self._num_samples = num_samples
         if generator is None:
-            # Ensure that underlying sampler has something repeatable
+            # Previously the random seed was fixed as 1. We changed it to a system-generated seed to ensure deterministic randomness.
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
             generator = torch.Generator()
-            generator.manual_seed(1)
+            generator.manual_seed(seed)
         self.generator = generator
         if not isinstance(self.replacement, bool):
             raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
```
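One design detail worth noting (the sketch below uses only stock PyTorch and is not from the commit): the sampler seeds its own private `torch.Generator` from the global RNG exactly once at construction, so iterating the sampler afterwards does not advance the global random stream:

```python
import torch

# Run 1: seed a private generator the way the sampler does, then shuffle.
torch.manual_seed(0)
seed = int(torch.empty((), dtype=torch.int64).random_().item())
g = torch.Generator()
g.manual_seed(seed)
_ = torch.randperm(10, generator=g)  # consumes only the private generator
next_global = torch.rand(1)          # next value from the global stream

# Run 2: identical setup, but with no shuffle at all.
torch.manual_seed(0)
_ = torch.empty((), dtype=torch.int64).random_()
next_global_no_shuffle = torch.rand(1)

# The shuffle did not advance the global stream.
assert torch.equal(next_global, next_global_no_shuffle)
```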
