|
8 | 8 |
|
def _mp_fn(index):
  """Per-replica entry point: validates xm.all_gather() results.

  Runs two checks on TPU/GPU devices and exits with status 1 on mismatch:
    1. A single-replica-group all_gather, whose result must equal the
       ordinals [0, world_size).
    2. A two-replica-group all_gather (even vs. odd ordinals), whose result
       must equal this replica's group members.

  Args:
    index: integer ordinal of this replica (supplied by xmp.spawn).
  """
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  if xm.xla_device_hw(device) in ('TPU', 'GPU'):
    # Testing with a single replica group
    ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
    result = xm.all_gather(ordinal_tensor, dim=0)

    cpu_result = result.cpu()
    # Gathering every replica's ordinal should reproduce [0, world_size).
    expected = torch.arange(0, world_size, dtype=torch.float)
    if not cpu_result.allclose(expected):
      print('xm.all_gather() produced wrong reductions', file=sys.stderr)
      print(f'[{index}] {cpu_result}', file=sys.stderr)
      sys.exit(1)

    # Testing with two replica groups
    if world_size % 2 == 0 and world_size > 1:
      # Group 0 holds the even ordinals, group 1 the odd ordinals.
      mp_groups = [[n for n in range(world_size) if n % 2 == 0],
                   [n for n in range(world_size) if n % 2 == 1]]
      # This replica belongs to group 0 if its index is even, group 1 if odd.
      replica_id = int(index % 2 == 1)

      result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups)

      cpu_result = result.cpu()
      # Each group gathers only its members' ordinals: evens start at 0,
      # odds start at 1, both stepping by 2.
      expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float)
      if not cpu_result.allclose(expected):
        print('xm.all_gather() produced wrong reductions', file=sys.stderr)
        print(f'[{index}] {cpu_result}', file=sys.stderr)
        sys.exit(1)
    else:
      # Odd world sizes cannot be split into two equal groups; skip the test.
      print(
          f'Failed to create two replica groups with {world_size} replicas',
          file=sys.stderr)

  else:
    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
26 | 46 |
|
27 | 47 |
|
28 | 48 | if __name__ == '__main__':
|
|
0 commit comments