Commit 5ae6400

[GraphBolt][CUDA] rank_sort_async for Cooperative Minibatching. (#7805)
1 parent 31ad9b5 · commit 5ae6400

File tree: 5 files changed, +52 -12 lines changed

graphbolt/src/cuda/cooperative_minibatching_utils.cu (+11)

@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 #include <cuda/functional>

+#include "../utils.h"
 #include "./common.h"
 #include "./cooperative_minibatching_utils.cuh"
 #include "./cooperative_minibatching_utils.h"
@@ -144,5 +145,15 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
   return results;
 }

+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size) {
+  return async(
+      [=] { return RankSort(nodes_list, rank, world_size); },
+      utils::is_on_gpu(nodes_list.at(0)));
+}
+
 }  // namespace cuda
 }  // namespace graphbolt
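Note: RankSortAsync simply defers the existing synchronous RankSort to GraphBolt's async helper and returns a Future, passing utils::is_on_gpu(nodes_list.at(0)) through to async (presumably to pick the GPU-aware execution path). A rough Python analogue of this wrap-and-submit pattern, with an illustrative stub in place of the real kernel (the actual RankSort partitions seeds by owning rank, which the stub does not attempt):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple

import torch

_pool = ThreadPoolExecutor(max_workers=1)  # stand-in for graphbolt's async runner


def rank_sort_stub(
    nodes_list: List[torch.Tensor], rank: int, world_size: int
) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    # Illustrative stand-in for the synchronous RankSort kernel: returns a
    # (sorted seeds, permutation index, offsets) tuple per node type.
    out = []
    for nodes in nodes_list:
        sorted_nodes, index = nodes.sort()
        offsets = torch.tensor([0, nodes.numel()])
        out.append((sorted_nodes, index, offsets))
    return out


def rank_sort_async_stub(nodes_list, rank, world_size) -> Future:
    # Same shape as the C++ wrapper: capture the arguments and run the
    # synchronous routine off the calling thread, handing back a future.
    return _pool.submit(rank_sort_stub, nodes_list, rank, world_size)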

graphbolt/src/cuda/cooperative_minibatching_utils.h (+7)

@@ -22,6 +22,7 @@
 #define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_

 #include <ATen/cuda/CUDAEvent.h>
+#include <graphbolt/async.h>
 #include <torch/script.h>

 namespace graphbolt {
@@ -83,6 +84,12 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
     const std::vector<torch::Tensor>& nodes_list, int64_t rank,
     int64_t world_size);

+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size);
+
 }  // namespace cuda
 }  // namespace graphbolt

graphbolt/src/python_binding.cc (+8)

@@ -59,6 +59,13 @@ TORCH_LIBRARY(graphbolt, m) {
           &Future<std::vector<std::tuple<
               torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>::
               Wait);
+  m.class_<Future<
+      std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>(
+      "RankSortFuture")
+      .def(
+          "wait",
+          &Future<std::vector<
+              std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
   m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
       "GpuGraphCacheQueryFuture")
       .def(
@@ -198,6 +205,7 @@ TORCH_LIBRARY(graphbolt, m) {
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
   m.def("rank_sort", &cuda::RankSort);
+  m.def("rank_sort_async", &cuda::RankSortAsync);
 #endif
 #ifdef HAS_IMPL_ABSTRACT_PYSTUB
   m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
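Note: with the RankSortFuture class and the rank_sort_async op registered above, the asynchronous path can be exercised from Python roughly as follows (a sketch assuming a CUDA build; the seed tensor and sizes are made up):

import torch
import dgl.graphbolt  # loads the GraphBolt extension and its ops (assumes a CUDA build)

seeds = [torch.randint(0, 1 << 20, (4096,), device="cuda")]
rank, world_size = 0, 4

# Kick off the sort without blocking the caller...
future = torch.ops.graphbolt.rank_sort_async(seeds, rank, world_size)
# ...other host-side work can proceed here...
# then block only when the result is needed; the payload matches rank_sort.
for sorted_seeds, index, offsets in future.wait():
    print(sorted_seeds.shape, index.shape, offsets.shape)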

python/dgl/graphbolt/subgraph_sampler.py (+24 -10)

@@ -140,6 +140,9 @@ def __init__(
         if cooperative:
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
             datapipe = datapipe.buffer()
+            datapipe = datapipe.transform(
+                self._seeds_cooperative_exchange_1_wait_future
+            ).buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
             datapipe = datapipe.buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
@@ -193,19 +196,32 @@ def _wait_preprocess_future(minibatch, cooperative: bool):
         return minibatch

     @staticmethod
-    def _seeds_cooperative_exchange_1(minibatch, group=None):
-        rank = thd.get_rank(group)
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_1(minibatch):
+        rank = thd.get_rank()
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         is_homogeneous = not isinstance(seeds, dict)
         if is_homogeneous:
             seeds = {"_N": seeds}
         if minibatch._seeds_offsets is None:
-            seeds_list = list(seeds.values())
-            result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
             assert minibatch.compacted_seeds is None
+            minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(
+                list(seeds.values()), rank, world_size
+            )
+        return minibatch
+
+    @staticmethod
+    def _seeds_cooperative_exchange_1_wait_future(minibatch):
+        world_size = thd.get_world_size()
+        seeds = minibatch._seed_nodes
+        is_homogeneous = not isinstance(seeds, dict)
+        if is_homogeneous:
+            seeds = {"_N": seeds}
+        num_ntypes = len(seeds.keys())
+        if minibatch._seeds_offsets is None:
+            result = minibatch._rank_sort_future.wait()
+            delattr(minibatch, "_rank_sort_future")
             sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
-            num_ntypes = len(seeds.keys())
             for i, (
                 seed_type,
                 (typed_sorted_seeds, typed_index, typed_offsets),
@@ -229,16 +245,15 @@ def _seeds_cooperative_exchange_1(minibatch, group=None):
             minibatch._counts_future = all_to_all(
                 counts_received.split(num_ntypes),
                 counts_sent.split(num_ntypes),
-                group=group,
                 async_op=True,
             )
             minibatch._counts_sent = counts_sent
             minibatch._counts_received = counts_received
         return minibatch

     @staticmethod
-    def _seeds_cooperative_exchange_2(minibatch, group=None):
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_2(minibatch):
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         minibatch._counts_future.wait()
         delattr(minibatch, "_counts_future")
@@ -256,7 +271,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
                 all_to_all(
                     typed_seeds_received.split(typed_counts_received),
                     typed_seeds.split(typed_counts_sent),
-                    group,
                 )
                 seeds_received[ntype] = typed_seeds_received
                 counts_sent[ntype] = typed_counts_sent
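Note: the sampler now splits the old single stage in two: _seeds_cooperative_exchange_1 only issues rank_sort_async and stores the future on the minibatch, and the new _seeds_cooperative_exchange_1_wait_future collects the result, with a .buffer() between them so one minibatch's sort overlaps the pipeline's work on another. A self-contained sketch of that launch/buffer/wait pattern using plain generators (names and the sorted() stand-in are illustrative, not GraphBolt's actual datapipes):

import collections
from concurrent.futures import ThreadPoolExecutor

_pool = ThreadPoolExecutor(max_workers=1)


def launch_stage(minibatches):
    for mb in minibatches:
        # Stage 1: issue the expensive op and stash the future; do not block.
        mb["future"] = _pool.submit(sorted, mb["seeds"])
        yield mb


def buffer(minibatches, size=1):
    # Keep up to `size` extra minibatches in flight, like graphbolt's Bufferer.
    queue = collections.deque()
    for mb in minibatches:
        queue.append(mb)
        if len(queue) > size:
            yield queue.popleft()
    while queue:
        yield queue.popleft()


def wait_stage(minibatches):
    for mb in minibatches:
        # Stage 2: block only when the result is needed; by now the next
        # minibatch's launch has already been issued.
        mb["seeds"] = mb.pop("future").result()
        yield mb


batches = [{"seeds": [3, 1, 2]}, {"seeds": [9, 7, 8]}]
for mb in wait_stage(buffer(launch_stage(iter(batches)))):
    print(mb["seeds"])  # [1, 2, 3] then [7, 8, 9]

With size=1, the second minibatch's submit happens before the first result is awaited, which is exactly the overlap the extra Bufferer stage buys.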

tests/python/pytorch/graphbolt/test_dataloader.py (+2 -2)

@@ -163,8 +163,8 @@ def test_gpu_sampling_DataLoader(
     if enable_feature_fetch:
         bufferer_cnt += 1  # feature fetch has 1.
     if cooperative:
-        # _preprocess stage and each sampling layer.
-        bufferer_cnt += 3
+        # _preprocess stage.
+        bufferer_cnt += 4
     datapipe_graph = traverse_dps(dataloader)
     bufferers = find_dps(
         datapipe_graph,
