14
14
from torch .utils .data import Dataset
15
15
16
16
from torchdata .stateful_dataloader import StatefulDataLoader
17
- from torchdata .stateful_dataloader .sampler import StatefulDistributedSampler
17
+ from torchdata .stateful_dataloader .sampler import RandomSampler , StatefulDistributedSampler
18
18
19
19
20
20
class MockDataset (Dataset ):
@@ -34,7 +34,10 @@ def __getitem__(self, idx):
34
34
"Fails with TSAN with the following error: starting new threads after multi-threaded "
35
35
"fork is not supported. Dying (set die_after_fork=0 to override)" ,
36
36
)
37
- @unittest .skipIf (TEST_WITH_ASAN , "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223" )
37
+ @unittest .skipIf (
38
+ TEST_WITH_ASAN ,
39
+ "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223" ,
40
+ )
38
41
class TestDataLoader (TestCase ):
39
42
def setUp (self ):
40
43
super ().setUp ()
@@ -44,7 +47,12 @@ def setUp(self):
44
47
def test_initialization_StatefulDistributedSampler (self ):
45
48
46
49
sampler = StatefulDistributedSampler (
47
- self .dataset , num_replicas = 10 , rank = 0 , shuffle = False , seed = 42 , drop_last = False
50
+ self .dataset ,
51
+ num_replicas = 10 ,
52
+ rank = 0 ,
53
+ shuffle = False ,
54
+ seed = 42 ,
55
+ drop_last = False ,
48
56
)
49
57
self .assertEqual (sampler .dataset , self .dataset )
50
58
self .assertEqual (sampler .num_replicas , 10 )
@@ -139,7 +147,8 @@ def test_drop_last_effect(self):
139
147
)
140
148
141
149
self .assertTrue (
142
- len (indices_with_drop ) <= len (indices_without_drop ), "Drop last should result in fewer or equal indices"
150
+ len (indices_with_drop ) <= len (indices_without_drop ),
151
+ "Drop last should result in fewer or equal indices" ,
143
152
)
144
153
145
154
def test_data_order_with_shuffle (self ):
@@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self):
153
162
for batch in dataloader :
154
163
data_loaded .extend (batch )
155
164
self .assertEqual (len (data_loaded ), len (self .dataset ), "All data should be loaded" )
156
- self .assertEqual (data_loaded , data_sampled , "Data loaded by DataLoader should match data sampled by sampler" )
165
+ self .assertEqual (
166
+ data_loaded ,
167
+ data_sampled ,
168
+ "Data loaded by DataLoader should match data sampled by sampler" ,
169
+ )
157
170
158
171
def test_data_order_without_shuffle (self ):
159
172
sampler = StatefulDistributedSampler (self .dataset , num_replicas = 1 , rank = 0 , shuffle = False )
@@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self):
167
180
for batch in dataloader :
168
181
data_loaded .extend (batch )
169
182
self .assertEqual (len (data_loaded ), len (self .dataset ), "All data should be loaded" )
170
- self .assertEqual (data_loaded , data_sampled , "Data loaded by DataLoader should match data sampled by sampler" )
171
- self .assertEqual (data_loaded , list (range (100 )), "Data loaded by DataLoader should be in original order" )
183
+ self .assertEqual (
184
+ data_loaded ,
185
+ data_sampled ,
186
+ "Data loaded by DataLoader should match data sampled by sampler" ,
187
+ )
188
+ self .assertEqual (
189
+ data_loaded ,
190
+ list (range (100 )),
191
+ "Data loaded by DataLoader should be in original order" ,
192
+ )
172
193
173
194
def test_data_distribution_across_replicas (self ):
174
195
num_replicas = 5
@@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self):
181
202
data_loaded .extend ([int (x .item ()) for x in batch ])
182
203
all_data .extend (data_loaded )
183
204
self .assertEqual (
184
- sorted (all_data ), list (range (100 )), "All data points should be covered exactly once across all replicas"
205
+ sorted (all_data ),
206
+ list (range (100 )),
207
+ "All data points should be covered exactly once across all replicas" ,
185
208
)
186
209
210
+ def test_seed_replicability (self ):
211
+ # Test that the same seed will result in the same data order
212
+ # We first pick a random number as seed, then use it to initialize two dataloaders
213
+ min_seed , max_seed = 0 , 1000 # [min_seed, max_seed)
214
+ seed = torch .randint (min_seed , max_seed , (1 ,), dtype = torch .int64 ).item ()
215
+ torch .manual_seed (seed )
216
+
217
+ dataloader1 = StatefulDataLoader (self .dataset , batch_size = 1 , shuffle = True )
218
+ results1 = list (dataloader1 )
219
+
220
+ # Repeat the same process with the same seed
221
+ torch .manual_seed (seed )
222
+ dataloader2 = StatefulDataLoader (self .dataset , batch_size = 1 , shuffle = True )
223
+ results2 = list (dataloader2 )
224
+
225
+ # Repeat the same process with a different seed, making sure that the seed is different
226
+ min_seed , max_seed = 1000 , 2000 # [min_seed, max_seed)
227
+ seed = torch .randint (min_seed , max_seed , (1 ,), dtype = torch .int64 ).item ()
228
+ torch .manual_seed (seed )
229
+ dataloader3 = StatefulDataLoader (self .dataset , batch_size = 1 , shuffle = True )
230
+ results3 = list (dataloader3 )
231
+
232
+ self .assertEqual (results1 , results2 , "Data should be replicable with same seed" )
233
+ self .assertNotEqual (results1 , results3 , "Data should not be replicable with different seed" )
234
+
187
235
188
236
# Script entry point: run this module's test cases via torch's test runner.
if __name__ == "__main__":
    run_tests()
0 commit comments