Skip to content

Commit 31ad9b5

Browse files
authored
[GraphBolt][CUDA] Fix Cooperative Minibatching bugs. (#7804)
1 parent 3bc8e22 commit 31ad9b5

File tree

5 files changed

+37
-28
lines changed

5 files changed

+37
-28
lines changed

graphbolt/src/cuda/cooperative_minibatching_utils.cu

+9-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
* implementations in CUDA.
2020
*/
2121
#include <graphbolt/cuda_ops.h>
22+
#include <thrust/scatter.h>
2223
#include <thrust/transform.h>
2324

2425
#include <cub/cub.cuh>
@@ -62,8 +63,7 @@ RankSortImpl(
6263
auto part_ids2 = part_ids.clone();
6364
auto part_ids2_sorted = torch::empty_like(part_ids2);
6465
auto nodes_sorted = torch::empty_like(nodes);
65-
auto index = ops::IndptrEdgeIdsImpl(
66-
offsets_dev, nodes.scalar_type(), torch::nullopt, nodes.numel());
66+
auto index = torch::arange(nodes.numel(), nodes.options());
6767
auto index_sorted = torch::empty_like(index);
6868
return AT_DISPATCH_INDEX_TYPES(
6969
nodes.scalar_type(), "RankSortImpl", ([&] {
@@ -100,8 +100,14 @@ RankSortImpl(
100100
index.data_ptr<index_t>(), index_sorted.data_ptr<index_t>(),
101101
nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0,
102102
num_bits);
103+
auto values = ops::IndptrEdgeIdsImpl(
104+
offsets_dev, nodes.scalar_type(), torch::nullopt, nodes.numel());
105+
THRUST_CALL(
106+
scatter, values.data_ptr<index_t>(),
107+
values.data_ptr<index_t>() + values.numel(),
108+
index_sorted.data_ptr<index_t>(), index.data_ptr<index_t>());
103109
return std::make_tuple(
104-
nodes_sorted, index_sorted, offsets, std::move(offsets_event));
110+
nodes_sorted, index, offsets, std::move(offsets_event));
105111
}));
106112
}
107113

graphbolt/src/cuda/cooperative_minibatching_utils.h

+15-14
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,21 @@ torch::Tensor RankAssignment(
4242

4343
/**
4444
* @brief Given node ids, the ranks they belong to, the offsets to separate
45-
* different node types and num_bits indicating the world size is <= 2^num_bits,
46-
* returns node ids sorted w.r.t. the ranks that the given ids belong along with
47-
* the original positions.
45+
* different node types and world size, returns node ids sorted w.r.t. the ranks
46+
* that the given ids belong to, along with their new positions.
4847
*
4948
* @param nodes Node id tensor to be mapped to a rank in [0, world_size).
5049
* @param part_ids Rank tensor the nodes belong to.
5150
* @param offsets_dev Offsets to separate different node types.
5251
* @param world_size World size, the total number of cooperating GPUs.
5352
*
54-
* @return (sorted_nodes, original_positions, rank_offsets, rank_offsets_event),
55-
* where the first one includes sorted nodes, the second contains original
56-
* positions of the sorted nodes and the third contains the offsets of the
57-
* sorted_nodes indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]]
58-
* contains nodes that belongs to the `i`th rank. Before accessing rank_offsets
59-
* on the CPU, `rank_offsets_event.synchronize()` is required.
53+
* @return (sorted_nodes, new_positions, rank_offsets, rank_offsets_event),
54+
* where the first one includes sorted nodes, the second contains new positions
55+
* of the given nodes, so that sorted_nodes[new_positions] == nodes, and the
56+
* third contains the offsets of the sorted_nodes indicating
57+
* sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that
58+
* belong to the `i`th rank. Before accessing rank_offsets on the CPU,
59+
* `rank_offsets_event.synchronize()` is required.
6060
*/
6161
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>
6262
RankSortImpl(
@@ -72,11 +72,12 @@ RankSortImpl(
7272
* @param rank Rank of the current GPU.
7373
* @param world_size World size, the total number of cooperating GPUs.
7474
*
75-
* @return vector of (sorted_nodes, original_positions, rank_offsets), where the
76-
* first one includes sorted nodes, the second contains original positions of
77-
* the sorted nodes and the third contains the offsets of the sorted_nodes
78-
* indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes
79-
* that belongs to the `i`th rank.
75+
* @return vector of (sorted_nodes, new_positions, rank_offsets), where the
76+
* first one includes sorted nodes, the second contains new positions of the
77+
* given nodes, so that sorted_nodes[new_positions] == nodes, and the third
78+
* contains the offsets of the sorted_nodes indicating
79+
* sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that
80+
* belong to the `i`th rank.
8081
*/
8182
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
8283
const std::vector<torch::Tensor>& nodes_list, int64_t rank,

graphbolt/src/cuda/extension/unique_and_compact_map.cu

+6-2
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,12 @@ __global__ void _MapIdsBatched(
9999

100100
auto slot = map.find(key);
101101
auto new_id = slot->second;
102-
if (index) new_id = index[new_id];
103-
mapped_ids[i] = new_id - unique_ids_offsets[batch_index];
102+
if (index) {
103+
new_id = index[new_id];
104+
} else {
105+
new_id -= unique_ids_offsets[batch_index];
106+
}
107+
mapped_ids[i] = new_id;
104108
}
105109

106110
i += stride;

python/dgl/graphbolt/impl/cooperative_conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ def backward(
7878
torch.split(typed_grad_output, counts_sent[ntype]),
7979
)
8080
i = out.new_empty(2, out.shape[0], dtype=torch.int64)
81-
i[0] = torch.arange(
81+
i[0] = seed_inverse_ids[ntype] # src
82+
i[1] = torch.arange(
8283
out.shape[0], device=typed_grad_output.device
83-
) # src
84-
i[1] = seed_inverse_ids[ntype] # dst
84+
) # dst
8585
coo = torch.sparse_coo_tensor(
8686
i,
8787
torch.ones(

tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
2020
@pytest.mark.parametrize("rank", list(range(WORLD_SIZE)))
21-
def test_gpu_cached_feature_read_async(dtype, rank):
21+
def test_rank_sort_and_unique_and_compact(dtype, rank):
2222
torch.manual_seed(7)
2323
nodes_list1 = [
2424
torch.randint(0, 2111111111, [777], dtype=dtype, device=F.ctx())
@@ -32,8 +32,8 @@ def test_gpu_cached_feature_read_async(dtype, rank):
3232
for i, ((nodes1, idx1, offsets1), (nodes2, idx2, offsets2)) in enumerate(
3333
zip(res1, res2)
3434
):
35-
assert_equal(nodes_list1[i], nodes1[idx1.sort()[1]])
36-
assert_equal(nodes_list2[i], nodes2[idx2.sort()[1]])
35+
assert_equal(nodes_list1[i], nodes1[idx1])
36+
assert_equal(nodes_list2[i], nodes2[idx2])
3737
assert_equal(offsets1, offsets2)
3838
assert offsets1.is_pinned() and offsets2.is_pinned()
3939

@@ -50,14 +50,12 @@ def test_gpu_cached_feature_read_async(dtype, rank):
5050
for (nodes1, idx1, offsets1), (nodes4, idx4, offsets4) in zip(res1, res4):
5151
off1 = offsets1.tolist()
5252
off4 = offsets4.tolist()
53+
assert_equal(nodes1[idx1], nodes4[idx4])
5354
for i in range(WORLD_SIZE):
5455
j = (i - rank + WORLD_SIZE) % WORLD_SIZE
5556
assert_equal(
5657
nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
5758
)
58-
assert_equal(
59-
idx1[off1[j] : off1[j + 1]], idx4[off4[i] : off4[i + 1]]
60-
)
6159

6260
unique, compacted, offsets = gb.unique_and_compact(
6361
nodes_list1[:1], rank, WORLD_SIZE

0 commit comments

Comments
 (0)