Fix freezing modules in Ghost Clipping (#729)
Summary:

Freezing modules with ghost clipping throws an error because the corresponding per-sample norms are not calculated. Fix: keep the list of all parameters in memory and check whether the corresponding `requires_grad` is True when calculating norms.

Further, unfreezing modules (with and without ghost clipping) wasn't supported because the hooks aren't present for the corresponding modules. Fix: override `requires_grad_` to add (or remove) the hooks.
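
As a usage illustration (not part of the commit), a minimal sketch of what the fix enables, assuming the ghost-clipping wrapper is exposed as `opacus.grad_sample.GradSampleModuleFastGradientClipping` and accepts the constructor arguments seen in the diff below; the toy model and layer index are illustrative:

```python
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping  # assumed import path

# Illustrative toy model.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Wrap the model for ghost clipping (constructor arguments assumed from the diff below).
gs_model = GradSampleModuleFastGradientClipping(
    model,
    batch_first=True,
    loss_reduction="mean",
    max_grad_norm=1.0,
    use_ghost_clipping=True,
)

# Freezing a single layer: with this fix, per-sample norms for its parameters are
# simply skipped during norm computation instead of raising an error.
model[0].requires_grad_(False)

# Freezing/unfreezing the whole wrapped module: the overridden requires_grad_
# removes and re-attaches the grad-sample hooks accordingly.
gs_model.requires_grad_(False)
gs_model.requires_grad_(True)
```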

Facebook
We initially used `trainable_parameters(module)` to traverse the list of trainable parameters upon norm computation. It was slow because `trainable_parameters(module)` is a generator that traverses the neural network graph every time it is consumed.

We replaced it with a list of trainable parameters fixed at model creation time. This is what led to the issue with freezing modules, since that list is not updated afterwards.

Fix: use the list of **all parameters** -- not a generator, so no traversal happens. Further, we check `requires_grad` when calculating each per-sample norm to decide whether to compute it. This is the same check the (non-private) [optimizer](https://github.com/pytorch/pytorch/blob/5725462cd8679dd1dea8a469b1bf2e71f226b664/torch/optim/optimizer.py#L963) performs to determine which parameters are frozen.
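
For illustration only, the same filtering pattern in isolation -- a standalone sketch with made-up stand-in norms (not the committed code) of aggregating per-parameter, per-sample norms while skipping frozen parameters, analogous to the optimizer check linked above:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)  # freeze the first layer

batch_size = 4
# Stand-in for the per-parameter, per-sample norms that ghost clipping computes
# (in the real module they live on param._norm_sample); frozen params have none.
norm_samples = {
    p: torch.rand(batch_size) for p in model.parameters() if p.requires_grad
}

# Aggregate into one per-example norm, skipping frozen parameters via requires_grad.
norm_sample = torch.stack(
    [norm_samples[p] for p in model.parameters() if p.requires_grad],
    dim=0,
).norm(2, dim=0)

print(norm_sample.shape)  # torch.Size([4]): one gradient norm per example
```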

Differential Revision: D68656459
EnayatUllah authored and facebook-github-bot committed Feb 11, 2025
1 parent 0d186a4 commit 9eb7875
Showing 2 changed files with 21 additions and 2 deletions.
14 changes: 14 additions & 0 deletions opacus/grad_sample/grad_sample_module.py
@@ -145,6 +145,20 @@ def __init__(
             force_functorch=force_functorch,
         )
 
+    def requires_grad_(self, requires_grad: bool = True) -> nn.Module:
+        """Rewrite requires_grad_ to add/remove hooks based on the requires_grad value."""
+        if requires_grad:
+            # Attach hooks to the module
+            self.add_hooks(
+                loss_reduction=self.loss_reduction,
+                batch_first=self.batch_first,
+                force_functorch=self.force_functorch,
+            )
+        else:
+            # Remove hooks
+            self.remove_hooks()
+        return super().requires_grad_(requires_grad)
+
     def forward(self, *args, **kwargs):
         return self._module(*args, **kwargs)

9 changes: 7 additions & 2 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -117,7 +117,7 @@ def __init__(
             strict=strict,
             force_functorch=force_functorch,
         )
-        self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
+        self.all_parameters = [p for p in self.parameters()]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
         self._per_sample_gradient_norms = None
@@ -130,7 +130,12 @@ def get_clipping_coef(self) -> torch.Tensor:
     def get_norm_sample(self) -> torch.Tensor:
         """Get per-example gradient norms."""
         norm_sample = torch.stack(
-            [param._norm_sample for param in self.trainable_parameters], dim=0
+            [
+                param._norm_sample
+                for param in self.all_parameters
+                if param.requires_grad
+            ],
+            dim=0,
         ).norm(2, dim=0)
         self.per_sample_gradient_norms = norm_sample
         return norm_sample
