When tested in https://github.com/AI-Hypercomputer/torchprime/pull/144/files, sharding the model parameters with MarkShardingFunction.apply causes Mixtral to OOM: the gradient HLO arrays end up living much longer than needed. Two profiles for comparison:

- Shard both activations and model parameters with MarkShardingFunction: http://shortn/_vvNPYfxSe3
- Shard activations with MarkShardingFunction and shard model parameters with xs.mark_sharding: http://shortn/_6OxaSdjJzQ
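For concreteness, here is a minimal sketch of the two setups being compared, using the MarkShardingFunction class quoted below. The mesh layout, shapes, and partition specs are made up for illustration; the real ones come from the torchprime Mixtral configs.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('fsdp', 'tensor'))

device = xm.xla_device()
weight = torch.nn.Parameter(torch.randn(1024, 1024, device=device))
activation = torch.randn(16, 1024, device=device)

# In both setups the activation goes through the autograd function, so the
# backward pass re-annotates the activation gradient as well.
activation = MarkShardingFunction.apply(activation, mesh, ('fsdp', None))

# Setup 1 (OOMs): the parameter also goes through MarkShardingFunction.apply,
# and the tensor it returns is what the forward computation consumes.
sharded_weight = MarkShardingFunction.apply(weight, mesh, ('fsdp', 'tensor'))
out = activation @ sharded_weight

# Setup 2 (no OOM): the parameter is annotated in place with xs.mark_sharding
# and used directly.
xs.mark_sharding(weight, mesh, ('fsdp', 'tensor'))
out = activation @ weight
```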
Another clue is that if I change MarkShardingFunction so that it is no longer in-place, the OOM goes away:
```python
import torch
# Imports assumed for a standalone snippet; inside torch_xla these names are
# module-local.
from torch_xla.distributed.spmd import Mesh, mark_sharding


class MarkShardingFunction(torch.autograd.Function):
    """
    Autograd function to mark_sharding on intermediate tensors and the gradient
    of the intermediate tensors during backward pass.

    Usage:

        new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))

    This is required to guide GSPMD sharding propagation better during the
    backward pass as during complicated workloads the compiler can introduce extra
    collectives that can hurt performance.
    """

    @staticmethod
    def forward(
        ctx, torch_tensor: torch.Tensor, mesh: Mesh, partition_spec: tuple
    ) -> torch.Tensor:
        # Annotate a clone instead of the input itself, so the function is no
        # longer in-place from autograd's point of view.
        o = mark_sharding(torch_tensor.clone(), mesh, partition_spec)
        ctx.partition_spec = partition_spec
        ctx.mesh = mesh
        return o.global_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        partition_spec = ctx.partition_spec
        mesh = ctx.mesh
        # Re-annotate the incoming gradient with the same sharding, again on a
        # clone. forward() took (torch_tensor, mesh, partition_spec), so return
        # one gradient per input; only the tensor input gets a gradient.
        o = mark_sharding(grad_output.clone(), mesh, partition_spec)
        return o.global_tensor, None, None
```
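For contrast, the in-place variant that triggers the OOM is, as far as I can reconstruct from the snippet above, the same function without the two clone() calls, i.e. the sharding annotation is attached directly to the input tensor and to the incoming gradient:

```python
    # Reconstructed in-place forward (for contrast only; not the fixed version):
    @staticmethod
    def forward(
        ctx, torch_tensor: torch.Tensor, mesh: Mesh, partition_spec: tuple
    ) -> torch.Tensor:
        # No clone(): the same buffer is both the layer's input and the
        # annotated output.
        o = mark_sharding(torch_tensor, mesh, partition_spec)
        ctx.partition_spec = partition_spec
        ctx.mesh = mesh
        return o.global_tensor
```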
This is so that we can use it in `scan` later. It has the side-effect of making the function no longer in-place, because PyTorch custom_op blows up if I don't clone the tensor. So it "fixes" #8809 as a side-effect.