File tree 1 file changed +8
-3
lines changed
1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change 16
16
from typing import List
17
17
18
18
import numpy as np
19
+ from torch .utils .data import BatchSampler , DataLoader , Sampler
20
+
19
21
from opacus .optimizers import DPOptimizer
20
22
from opacus .utils .uniform_sampler import (
21
23
DistributedUniformWithReplacementSampler ,
22
24
UniformWithReplacementSampler ,
23
25
)
24
- from torch .utils .data import BatchSampler , DataLoader , Sampler
25
26
26
27
27
28
class BatchSplittingSampler (Sampler [List [int ]]):
@@ -71,13 +72,17 @@ def __iter__(self):
71
72
def __len__ (self ):
72
73
if isinstance (self .sampler , BatchSampler ):
73
74
return int (
74
- len (self .sampler ) * (self .sampler .batch_size / self .max_batch_size )
75
+ np .ceil (
76
+ len (self .sampler ) * (self .sampler .batch_size / self .max_batch_size )
77
+ )
75
78
)
76
79
elif isinstance (self .sampler , UniformWithReplacementSampler ) or isinstance (
77
80
self .sampler , DistributedUniformWithReplacementSampler
78
81
):
79
82
expected_batch_size = self .sampler .sample_rate * self .sampler .num_samples
80
- return int (len (self .sampler ) * (expected_batch_size / self .max_batch_size ))
83
+ return int (
84
+ np .ceil (len (self .sampler ) * (expected_batch_size / self .max_batch_size ))
85
+ )
81
86
82
87
return len (self .sampler )
83
88
You can’t perform that action at this time.
0 commit comments