@@ -140,6 +140,9 @@ def __init__(
140
140
if cooperative :
141
141
datapipe = datapipe .transform (self ._seeds_cooperative_exchange_1 )
142
142
datapipe = datapipe .buffer ()
143
+ datapipe = datapipe .transform (
144
+ self ._seeds_cooperative_exchange_1_wait_future
145
+ ).buffer ()
143
146
datapipe = datapipe .transform (self ._seeds_cooperative_exchange_2 )
144
147
datapipe = datapipe .buffer ()
145
148
datapipe = datapipe .transform (self ._seeds_cooperative_exchange_3 )
@@ -193,19 +196,32 @@ def _wait_preprocess_future(minibatch, cooperative: bool):
193
196
return minibatch
194
197
195
198
@staticmethod
196
- def _seeds_cooperative_exchange_1 (minibatch , group = None ):
197
- rank = thd .get_rank (group )
198
- world_size = thd .get_world_size (group )
199
+ def _seeds_cooperative_exchange_1 (minibatch ):
200
+ rank = thd .get_rank ()
201
+ world_size = thd .get_world_size ()
199
202
seeds = minibatch ._seed_nodes
200
203
is_homogeneous = not isinstance (seeds , dict )
201
204
if is_homogeneous :
202
205
seeds = {"_N" : seeds }
203
206
if minibatch ._seeds_offsets is None :
204
- seeds_list = list (seeds .values ())
205
- result = torch .ops .graphbolt .rank_sort (seeds_list , rank , world_size )
206
207
assert minibatch .compacted_seeds is None
208
+ minibatch ._rank_sort_future = torch .ops .graphbolt .rank_sort_async (
209
+ list (seeds .values ()), rank , world_size
210
+ )
211
+ return minibatch
212
+
213
+ @staticmethod
214
+ def _seeds_cooperative_exchange_1_wait_future (minibatch ):
215
+ world_size = thd .get_world_size ()
216
+ seeds = minibatch ._seed_nodes
217
+ is_homogeneous = not isinstance (seeds , dict )
218
+ if is_homogeneous :
219
+ seeds = {"_N" : seeds }
220
+ num_ntypes = len (seeds .keys ())
221
+ if minibatch ._seeds_offsets is None :
222
+ result = minibatch ._rank_sort_future .wait ()
223
+ delattr (minibatch , "_rank_sort_future" )
207
224
sorted_seeds , sorted_compacted , sorted_offsets = {}, {}, {}
208
- num_ntypes = len (seeds .keys ())
209
225
for i , (
210
226
seed_type ,
211
227
(typed_sorted_seeds , typed_index , typed_offsets ),
@@ -229,16 +245,15 @@ def _seeds_cooperative_exchange_1(minibatch, group=None):
229
245
minibatch ._counts_future = all_to_all (
230
246
counts_received .split (num_ntypes ),
231
247
counts_sent .split (num_ntypes ),
232
- group = group ,
233
248
async_op = True ,
234
249
)
235
250
minibatch ._counts_sent = counts_sent
236
251
minibatch ._counts_received = counts_received
237
252
return minibatch
238
253
239
254
@staticmethod
240
- def _seeds_cooperative_exchange_2 (minibatch , group = None ):
241
- world_size = thd .get_world_size (group )
255
+ def _seeds_cooperative_exchange_2 (minibatch ):
256
+ world_size = thd .get_world_size ()
242
257
seeds = minibatch ._seed_nodes
243
258
minibatch ._counts_future .wait ()
244
259
delattr (minibatch , "_counts_future" )
@@ -256,7 +271,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
256
271
all_to_all (
257
272
typed_seeds_received .split (typed_counts_received ),
258
273
typed_seeds .split (typed_counts_sent ),
259
- group ,
260
274
)
261
275
seeds_received [ntype ] = typed_seeds_received
262
276
counts_sent [ntype ] = typed_counts_sent
0 commit comments