@@ -836,6 +836,10 @@ def test_mark_sharding_ir(self):

    self.assertTrue(torch.allclose(expected, actual.cpu()))

+  def _check_sharding_annotation(self, tensor, sharding_annotation):
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
+    self.assertIn(sharding_annotation, hlo)
+
  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                       "Multiple devices required for autograd sharding test")
  def test_mark_sharding_autograd(self):
@@ -849,9 +853,56 @@ def test_mark_sharding_autograd(self):
    t = y.sum()
    # Backward pass
    t.backward()
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([z.grad])
-    sharding_annotation = 'sharding={devices=[1,%d]' % self.n_devices
-    self.assertIn(sharding_annotation, hlo)
+    self._check_sharding_annotation(z.grad,
+                                    'sharding={devices=[1,%d]' % self.n_devices)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for autograd sharding test")
+  def test_mark_sharding_aot_compile(self):
+    mesh = self._get_mesh((self.n_devices,))
+
+    def my_fn(x):
+      z = torch.sin(x)
+      y = MarkShardingFunction.apply(z, mesh, (0,))
+      return y + 42
+
+    from functorch.compile import aot_function, make_boxed_func  # type: ignore
+
+    x = torch.randn(8)
+    x = x.to('xla').requires_grad_(True)
+
+    graphs = []
+
+    def get_graph(gm: torch.fx.GraphModule, _):
+      graphs.append(gm)
+      return make_boxed_func(gm)
+
+    y = aot_function(my_fn, get_graph)(x)
+    t = y.sum()
+    t.backward()
+    torch_xla.sync()
+
+    sharding_spec = '{devices=[%d]' % self.n_devices
+
+    # Check that the output has sharding.
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(y))
+
+    # Check that the gradient has sharding.
+    self.assertIsNotNone(x.grad)
+    self.assertIn(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(x.grad))
+
+    # Check that each of the AOTAutograd-captured graphs also contains a mark_sharding.
+    fwd, bwd = graphs
+
+    inp = torch.randn(8).to('xla').requires_grad_(False)
+    out, *residuals = fwd(inp)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)
+
+    tangents = torch.randn(8).to('xla').requires_grad_(False)
+    out, = bwd(*residuals, tangents)
+    self._check_sharding_annotation(out,
+                                    'sharding={devices=[%d]' % self.n_devices)

  def test_sharded_tensor_aliasing(self):
    met.clear_all()