|
| 1 | +import unittest |
| 2 | + |
| 3 | +from functools import partial |
| 4 | + |
| 5 | +import backend as F |
| 6 | +import dgl.graphbolt as gb |
| 7 | +import pytest |
| 8 | +import torch |
| 9 | + |
# Number of ranks handed to rank_sort; the test below is also parametrized
# over every rank in range(WORLD_SIZE).
WORLD_SIZE = 7

# torch.testing.assert_close with both tolerances forced to zero, so any
# difference between the compared values is reported as a failure.
assert_equal = partial(
    torch.testing.assert_close,
    atol=0,
    rtol=0,
)
| 13 | + |
| 14 | + |
@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="This test requires an NVIDIA GPU.",
)
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("rank", list(range(WORLD_SIZE)))
def test_gpu_cached_feature_read_async(dtype, rank):
    """Exercise ``torch.ops.graphbolt.rank_sort`` on random id tensors.

    For a list of ten random tensors (and the same tensors pre-sorted) the
    test checks that:

    * indexing the returned nodes with the sort permutation of the returned
      indices reproduces the corresponding input tensor;
    * sorted and unsorted inputs yield identical offsets, and those offsets
      are pinned;
    * a second call with identical arguments returns identical outputs;
    * relative to a call with ``rank=0``, each offset-delimited slice holds
      the same content and only its slot position is rotated by ``rank``.
    """
    unsorted_inputs = []
    for _ in range(10):
        unsorted_inputs.append(
            torch.randint(0, 11111111, [777], dtype=dtype, device=F.ctx())
        )
    sorted_inputs = [torch.sort(ids)[0] for ids in unsorted_inputs]

    out_unsorted = torch.ops.graphbolt.rank_sort(
        unsorted_inputs, rank, WORLD_SIZE
    )
    out_sorted = torch.ops.graphbolt.rank_sort(
        sorted_inputs, rank, WORLD_SIZE
    )

    for pos, (from_unsorted, from_sorted) in enumerate(
        zip(out_unsorted, out_sorted)
    ):
        nodes_u, idx_u, offsets_u = from_unsorted
        nodes_s, idx_s, offsets_s = from_sorted
        # The sort permutation of the returned indices must map the
        # returned nodes back onto the tensor that was passed in.
        assert_equal(unsorted_inputs[pos], nodes_u[torch.sort(idx_u)[1]])
        assert_equal(sorted_inputs[pos], nodes_s[torch.sort(idx_s)[1]])
        assert_equal(offsets_u, offsets_s)
        assert offsets_u.is_pinned()
        assert offsets_s.is_pinned()

    # Same arguments a second time: every output must match the first call.
    out_repeat = torch.ops.graphbolt.rank_sort(
        unsorted_inputs, rank, WORLD_SIZE
    )
    for first, again in zip(out_unsorted, out_repeat):
        nodes_a, idx_a, offsets_a = first
        nodes_b, idx_b, offsets_b = again
        assert_equal(nodes_a, nodes_b)
        assert_equal(idx_a, idx_b)
        assert_equal(offsets_a, offsets_b)

    # Changing the rank only rotates which slot each slice lands in.
    out_rank0 = torch.ops.graphbolt.rank_sort(unsorted_inputs, 0, WORLD_SIZE)
    for ranked, baseline in zip(out_unsorted, out_rank0):
        nodes_r, idx_r, offsets_r = ranked
        nodes_0, idx_0, offsets_0 = baseline
        bounds_r = offsets_r.tolist()
        bounds_0 = offsets_0.tolist()
        for part in range(WORLD_SIZE):
            shifted = (part - rank + WORLD_SIZE) % WORLD_SIZE
            window_r = slice(bounds_r[shifted], bounds_r[shifted + 1])
            window_0 = slice(bounds_0[part], bounds_0[part + 1])
            assert_equal(nodes_r[window_r], nodes_0[window_0])
            assert_equal(idx_r[window_r], idx_0[window_0])
0 commit comments