MarkShardingFunction causes OOM when applied to model parameters #8809

Open
tengyifei opened this issue Mar 8, 2025 · 2 comments
@tengyifei (Collaborator) commented:

As tested in https://github.com/AI-Hypercomputer/torchprime/pull/144/files, sharding model parameters with MarkShardingFunction.apply causes Mixtral to OOM: gradient HLO arrays end up living much longer than needed.

Shard both activations and model parameters with MarkShardingFunction: http://shortn/_vvNPYfxSe3
Shard activation with MarkShardingFunction and shard model parameters with xs.mark_sharding: http://shortn/_6OxaSdjJzQ
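
For illustration, a minimal sketch of the two ways of sharding a parameter (the mesh, layer, and partition spec are placeholders; MarkShardingFunction is the autograd function quoted below):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
# Hypothetical 1-D 'fsdp' mesh and layer, just for illustration.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('fsdp',))
layer = torch.nn.Linear(1024, 1024).to(xm.xla_device())

# (a) OOMs in the linked PR: annotates the parameter *and* records an
#     autograd node that re-shards its gradient during the backward pass.
MarkShardingFunction.apply(layer.weight, mesh, ('fsdp', None))

# (b) No OOM: a plain in-place sharding annotation that leaves nothing on
#     the autograd graph.
xs.mark_sharding(layer.weight, mesh, ('fsdp', None))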

Another clue: if I change MarkShardingFunction to not operate in place, the OOM goes away:

import torch
from torch_xla.distributed.spmd import Mesh, mark_sharding


class MarkShardingFunction(torch.autograd.Function):
  """
  Autograd function that calls mark_sharding on an intermediate tensor in the
  forward pass and on its gradient in the backward pass.

  Usage:
  new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))

  This is required to better guide GSPMD sharding propagation during the
  backward pass: in complicated workloads the compiler can otherwise introduce
  extra collectives that hurt performance.
  """

  @staticmethod
  def forward(
    ctx, torch_tensor: torch.Tensor, mesh: Mesh, partition_spec: tuple
  ) -> torch.Tensor:
    # Clone so the sharding annotation is applied to a new tensor instead of
    # mutating the input in place.
    o = mark_sharding(torch_tensor.clone(), mesh, partition_spec)
    ctx.partition_spec = partition_spec
    ctx.mesh = mesh
    return o.global_tensor

  @staticmethod
  def backward(
    ctx, grad_output: torch.Tensor
  ) -> tuple[torch.Tensor, None, None]:
    partition_spec = ctx.partition_spec
    mesh = ctx.mesh
    # Re-annotate the incoming gradient; only the tensor input needs a
    # gradient, so return None for the mesh and partition_spec arguments.
    o = mark_sharding(grad_output.clone(), mesh, partition_spec)
    return o.global_tensor, None, None
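
With this non-in-place variant, the annotation lands on the returned clone rather than on the input, so call sites have to consume the returned tensor. A rough sketch of what that looks like for a parameter (module, inputs, mesh, and partition spec are placeholders):

# Hypothetical call site: shard a weight via the returned tensor and feed that
# tensor into the computation instead of mutating module.weight in place.
sharded_weight = MarkShardingFunction.apply(module.weight, mesh, ('fsdp', None))
output = torch.nn.functional.linear(inputs, sharded_weight, module.bias)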
@ManfeiBai (Collaborator) commented:

Hi @tengyifei, is it okay to assign this ticket to you?

@tengyifei (Collaborator, Author) commented:

Yes

tengyifei added a commit that referenced this issue Mar 12, 2025
This is so that we can use it in `scan` later.

This has the side-effect of making the function no longer in-place
because PyTorch custom_op blows up if I don't clone the tensor. So it
"fixes" #8809 as a side-effect.