Commit 3b12115

Merge xm all_gather patch (#3416)
* Set proper shard_count for all_gather when replica groups are non-empty.
* Update test_mp_all_gather.py
1 parent c91b766 commit 3b12115
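
The gist of the patch, as a small illustration (not part of the commit): when `groups` is non-empty, each all_gather only spans the replicas inside one group, so `shard_count` must be the group size rather than `None`. With a hypothetical 8-replica setup split the same way as the updated test:

world_size = 8  # hypothetical replica count, for illustration only
groups = [[n for n in range(world_size) if n % 2 == 0],  # [0, 2, 4, 6]
          [n for n in range(world_size) if n % 2 == 1]]  # [1, 3, 5, 7]
shard_count = len(groups[0])  # 4: each gather spans 4 shards, not None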

File tree

2 files changed: +33 -7 lines changed

test/test_mp_all_gather.py

+26 -6
@@ -8,21 +8,41 @@

 def _mp_fn(index):
   device = xm.xla_device()
+  world_size = xm.xrt_world_size()
   if xm.xla_device_hw(device) in ('TPU', 'GPU'):
-    world_size = xm.xrt_world_size()
+    # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
-    result = xm.all_gather(ordinal_tensor)
+    result = xm.all_gather(ordinal_tensor, dim=0)

     cpu_result = result.cpu()
     expected = torch.arange(0, world_size, dtype=torch.float)
     if not cpu_result.allclose(expected):
       print('xm.all_gather() produced wrong reductions', file=sys.stderr)
-      print('[{}] {}'.format(index, cpu_result), file=sys.stderr)
+      print(f'[{index}] {cpu_result}', file=sys.stderr)
       sys.exit(1)
+
+    # Testing with two replica groups
+    if world_size % 2 == 0 and world_size > 1:
+      mp_groups = [[n for n in range(world_size) if n % 2 == 0],
+                   [n for n in range(world_size) if n % 2 == 1]]
+      group_size = len(mp_groups[0])
+      replica_id = int(index % 2 == 1)
+
+      result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)
+
+      cpu_result = result.cpu()
+      expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print('xm.all_gather() produced wrong reductions', file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+    else:
+      print(
+          f'Failed to create two replica groups with {world_size} replicas',
+          file=sys.stderr)
+
   else:
-    print(
-        'Default device {} is not a TPU or GPU device'.format(device),
-        file=sys.stderr)
+    print(f'{device} is not a TPU or GPU device', file=sys.stderr)


 if __name__ == '__main__':
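
For context, the `if __name__ == '__main__':` guard closing the hunk is existing code; PyTorch/XLA multiprocess tests conventionally launch `_mp_fn` on every local device with `xmp.spawn`, roughly as sketched below. The real guard body sits outside this hunk, so treat this as an assumption:

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
  # Assumed launcher, following the usual PyTorch/XLA test pattern;
  # _mp_fn is the function defined in the diff above.
  xmp.spawn(_mp_fn, args=())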

torch_xla/core/xla_model.py

+7 -1
@@ -603,7 +603,13 @@ def all_gather(value, dim=0, groups=None, output=None):
   if dim < 0:
     dim = value.dim() + dim
   token, devctx = _get_all_reduce_token()
-  shard_count = None if groups else xrt_world_size()
+  if groups:
+    shard_count = len(groups[0])
+    assert all(len(group) == shard_count for group in groups), \
+        "Replica groups must have the same number of replicas/shards."
+  else:
+    # All replicas belong to a single group
+    shard_count = xrt_world_size()
   if output != None:
     # Call the out of place version of the all_gather
     new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
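
The same selection logic, pulled out as a standalone sketch that runs without an XLA device. `_shard_count` is a hypothetical helper name used only for illustration; in the commit itself the logic stays inline in `all_gather` as shown above:

def _shard_count(groups, world_size):
  # Mirrors the shard_count selection added in this commit.
  if groups:
    shard_count = len(groups[0])
    assert all(len(group) == shard_count for group in groups), \
        "Replica groups must have the same number of replicas/shards."
    return shard_count
  # All replicas belong to a single group.
  return world_size

print(_shard_count([[0, 2], [1, 3]], 4))  # 2
print(_shard_count(None, 4))              # 4
# _shard_count([[0, 1, 2], [3]], 4) would raise AssertionError (uneven groups)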
