Commit e4c1908

committed Mar 12, 2025
Update MarkShardingFunction to be AOTAutograd traceable
This is so that we can use it in `scan` later. This has the side effect of making the function no longer in-place, because PyTorch custom_op blows up if the tensor is not cloned. As a side effect, it also "fixes" #8809.
1 parent 53b3ab8 commit e4c1908
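The change works by replacing the in-place `mark_sharding` call inside the autograd function with a call to a functional custom operator that has a fake (meta) implementation. A minimal sketch of that pattern, outside torch_xla and with a hypothetical op name (`mylib::tag_copy`), assuming PyTorch 2.4+ for `torch.library.custom_op`:

import torch

# Functional custom op: the real implementation must not mutate its input,
# hence the clone (this mirrors why the commit drops in-place sharding).
@torch.library.custom_op("mylib::tag_copy", mutates_args=())
def tag_copy(t: torch.Tensor) -> torch.Tensor:
  return t.clone()

# Shape/dtype-only stand-in used while AOTAutograd traces the graph.
@tag_copy.register_fake
def _(t: torch.Tensor) -> torch.Tensor:
  return torch.empty_like(t)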

2 files changed: +93 −10 lines changed

test/spmd/test_xla_sharding.py (+54 −3)
@@ -836,6 +836,10 @@ def test_mark_sharding_ir(self):
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
 
+  def _check_sharding_annotation(self, tensor, sharding_annotation):
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
+    self.assertIn(sharding_annotation, hlo)
+
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required for autograd sharding test")
   def test_mark_sharding_autograd(self):
@@ -849,9 +853,56 @@ def test_mark_sharding_autograd(self):
     t = y.sum()
     # Backward pass
     t.backward()
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([z.grad])
-    sharding_annotation = 'sharding={devices=[1,%d]' % self.n_devices
-    self.assertIn(sharding_annotation, hlo)
+    self._check_sharding_annotation(z.grad,
+                                    'sharding={devices=[1,%d]' % self.n_devices)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for autograd sharding test")
+  def test_mark_sharding_aot_compile(self):
+    mesh = self._get_mesh((self.n_devices,))
+
+    def my_fn(x):
+      z = torch.sin(x)
+      y = MarkShardingFunction.apply(z, mesh, (0,))
+      return y + 42
+
+    from functorch.compile import aot_function, make_boxed_func  # type: ignore
+
+    x = torch.randn(8)
+    x = x.to('xla').requires_grad_(True)
+
+    graphs = []
+
+    def get_graph(gm: torch.fx.GraphModule, _):
+      graphs.append(gm)
+      return make_boxed_func(gm)
+
+    y = aot_function(my_fn, get_graph)(x)
+    t = y.sum()
+    t.backward()
+    torch_xla.sync()
+
+    sharding_spec = '{devices=[%d]' % self.n_devices
+
+    # Check that the output has sharding.
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(y))
+
+    # Check that the gradient has sharding.
+    self.assertIsNotNone(x.grad)
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(x.grad))
+
+    # Check that each AOTAutograd captured graph also contains a mark_sharding.
+    fwd, bwd = graphs
+
+    inp = torch.randn(8).to('xla').requires_grad_(False)
+    out, *residuals = fwd(inp)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)
+
+    tangents = torch.randn(8).to('xla').requires_grad_(False)
+    out, = bwd(*residuals, tangents)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)
 
   def test_sharded_tensor_aliasing(self):
     met.clear_all()

torch_xla/distributed/spmd/xla_sharding.py (+39 −7)
@@ -1250,24 +1250,56 @@ class MarkShardingFunction(torch.autograd.Function):
   of the intermediate tensors during backward pass.
 
   Usage:
-  new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
+
+  >>> new_tensor = MarkShardingFunction.apply(tensor, mesh, ('axis_1', 'axis_2'))
 
   This is required to guide GSPMD sharding propagation better during the
   backward pass as during complicated workloads the compiler can introduce extra
   collectives that can hurt performance.
+
+  Compared to `mark_sharding`, this version will not in-place shard input tensors.
+  Instead it takes in an unsharded tensor and returns a new tensor that is sharded.
+  After GSPMD sharding propagation in the compiler, both tensors will become sharded.
+
+  This version can also be used in AOTAutograd.
   """
 
   @staticmethod
   def forward(ctx, torch_tensor: torch.Tensor, mesh: Mesh,
-              partition_spec: Tuple) -> torch.Tensor:
-    mark_sharding(torch_tensor, mesh, partition_spec)
+              partition_spec) -> torch.Tensor:
+    o = _aot_mark_sharding(torch_tensor, str(mesh), str(partition_spec))
     ctx.partition_spec = partition_spec
     ctx.mesh = mesh
-    return torch_tensor
+    return o
 
   @staticmethod
-  def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+  def backward(ctx, grad_output: torch.Tensor):  # type: ignore
     partition_spec = ctx.partition_spec
     mesh = ctx.mesh
-    mark_sharding(grad_output, mesh, partition_spec)
-    return grad_output, None, None
+    o = _aot_mark_sharding(grad_output, str(mesh), str(partition_spec))
+    return o, None, None
+
+
+@torch.library.custom_op("xla::aot_mark_sharding", mutates_args=())
+def _aot_mark_sharding(t: torch.Tensor, mesh: str,
+                       partition_spec: str) -> torch.Tensor:
+  if t is None:
+    return None
+
+  import ast
+
+  import torch_xla.distributed.spmd as xs
+
+  the_mesh = xs.Mesh.from_str(mesh)
+  assert the_mesh is not None
+  partition_spec_eval = ast.literal_eval(partition_spec)
+  return xs.mark_sharding(t.clone(), the_mesh,
+                          partition_spec_eval).global_tensor
+
+
+@_aot_mark_sharding.register_fake
+def aot_mark_sharding_fake(t: torch.Tensor, mesh: str,
+                           partition_spec: str) -> torch.Tensor:
+  if t is None:
+    return None
+  return torch.empty_like(t)
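For context on the docstring above: a hedged usage sketch contrasting the in-place `xs.mark_sharding` with the functional `MarkShardingFunction`. The mesh shape, tensor sizes, and partition specs are illustrative assumptions, and it requires an SPMD-enabled runtime with more than one device.

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.xla_sharding import MarkShardingFunction

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,))

# In-place style: annotates `t` itself with the sharding spec.
t = torch.randn(16, 8).to('xla')
xs.mark_sharding(t, mesh, (0, None))

# Functional style: `u` is left alone and a new, sharded tensor is returned,
# which is what makes the op traceable by AOTAutograd; after GSPMD sharding
# propagation both tensors end up sharded.
u = torch.randn(16, 8).to('xla').requires_grad_(True)
v = MarkShardingFunction.apply(u, mesh, (0, None))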
