Add per-sample gradient norm computation as a functionality (pytorch#724)

Summary:

Per-sample gradient norms are already computed as part of Ghost Clipping, but they can be useful more generally. This change exposes them as a `per_sample_gradient_norms` property on the wrapped model.


```
...

loss.backward()
per_sample_norms = model.per_sample_gradient_norms

```
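
For context, a more complete (hypothetical) end-to-end sketch of how the property can be read during a Ghost Clipping training step. The `PrivacyEngine.make_private(..., grad_sample_mode="ghost")` call and its four return values follow the Opacus fast gradient clipping tutorial and should be checked against your Opacus version; the model, data, and hyperparameters are placeholders.

```
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

# Toy model and data (placeholders for illustration only).
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
# With grad_sample_mode="ghost", make_private wraps the model in
# GradSampleModuleFastGradientClipping and also returns a wrapped criterion.
model, optimizer, criterion, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    criterion=criterion,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for x, y in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()  # per-sample norms are computed as part of clipping
    per_sample_norms = model.per_sample_gradient_norms  # shape: (batch_size,)
    optimizer.step()
    break
```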

Differential Revision: D68634969
EnayatUllah authored and facebook-github-bot committed Jan 28, 2025
1 parent c7d6144 commit ba86c7f
Showing 1 changed file with 15 additions and 0 deletions.
opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
```
@@ -120,6 +120,7 @@ def __init__(
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
+        self._per_sample_gradient_norms = None
 
     def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
@@ -131,6 +132,7 @@ def get_norm_sample(self) -> torch.Tensor:
         norm_sample = torch.stack(
             [param._norm_sample for param in self.trainable_parameters], dim=0
         ).norm(2, dim=0)
+        self.per_sample_gradient_norms = norm_sample
         return norm_sample
 
     def capture_activations_hook(
@@ -231,3 +233,16 @@ def capture_backprops_hook(
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
+
+    @property
+    def per_sample_gradient_norms(self) -> torch.Tensor:
+        if self._per_sample_gradient_norms is not None:
+            return self._per_sample_gradient_norms
+        else:
+            raise AttributeError(
+                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
+            )
+
+    @per_sample_gradient_norms.setter
+    def per_sample_gradient_norms(self, value):
+        self._per_sample_gradient_norms = value
```
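
As a side note (not part of the commit), the aggregation in `get_norm_sample` can be illustrated in isolation: the per-parameter, per-sample norms stored in `param._norm_sample` are stacked and reduced with an L2 norm across parameters, so each entry of the result is the norm of that sample's full flattened gradient. The tensors below are made-up values for three parameters and a batch of four samples.

```
import torch

# Hypothetical per-parameter, per-sample gradient norms (batch of 4 samples).
per_param_norms = [
    torch.tensor([1.0, 2.0, 0.5, 3.0]),  # parameter 1
    torch.tensor([0.0, 1.0, 0.5, 4.0]),  # parameter 2
    torch.tensor([2.0, 2.0, 1.0, 0.0]),  # parameter 3
]

# Same reduction as get_norm_sample: stack to (num_params, batch_size),
# then take the L2 norm over the parameter dimension for each sample.
norm_sample = torch.stack(per_param_norms, dim=0).norm(2, dim=0)
print(norm_sample)  # tensor([2.2361, 3.0000, 1.2247, 5.0000])
```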
