Commit 0145b75

Author: Dariush Wahdany (committed)

Fix BatchMemoryManager length

1 parent 2ccbcc7 · commit 0145b75

File tree: 1 file changed, +8 −3 lines


opacus/utils/batch_memory_manager.py (+8 −3)
@@ -16,12 +16,13 @@
 from typing import List
 
 import numpy as np
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+
 from opacus.optimizers import DPOptimizer
 from opacus.utils.uniform_sampler import (
     DistributedUniformWithReplacementSampler,
     UniformWithReplacementSampler,
 )
-from torch.utils.data import BatchSampler, DataLoader, Sampler
 
 
 class BatchSplittingSampler(Sampler[List[int]]):
@@ -71,13 +72,17 @@ def __iter__(self):
     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
             return int(
-                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                np.ceil(
+                    len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                )
             )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return int(
+                np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            )
 
         return len(self.sampler)
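
For context, `BatchSplittingSampler.__len__` only estimates how many physical batches will be yielded, because the logical batches it splits (especially under Poisson sampling) have variable size. The fix replaces plain `int()` truncation with `int(np.ceil(...))`, so the estimate is rounded up rather than down and no longer under-reports the batch count. A minimal sketch of the arithmetic for the `UniformWithReplacementSampler` branch, with made-up sampler parameters (not taken from the commit):

import numpy as np

# Hypothetical values, chosen only to show the rounding difference.
num_steps = 100          # len(self.sampler): number of logical batches
sample_rate = 0.004      # Poisson sampling rate
num_samples = 50_000     # dataset size
max_batch_size = 128     # physical batch size limit

expected_batch_size = sample_rate * num_samples                # 200.0 samples per logical batch
estimate = num_steps * (expected_batch_size / max_batch_size)  # 156.25

old_len = int(estimate)           # 156 -- truncated down (pre-fix behaviour)
new_len = int(np.ceil(estimate))  # 157 -- rounded up (post-fix behaviour)

print(old_len, new_len)

Rounding the estimate down matters for callers that trust `len(data_loader)`, such as progress bars or step-based schedulers, which would otherwise appear to finish before the loader is exhausted.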
