@@ -42,21 +42,21 @@ torch::Tensor RankAssignment(
42
42
43
43
/* *
44
44
* @brief Given node ids, the ranks they belong, the offsets to separate
45
- * different node types and num_bits indicating the world size is <= 2^num_bits,
46
- * returns node ids sorted w.r.t. the ranks that the given ids belong along with
47
- * the original positions.
45
+ * different node types and world size, returns node ids sorted w.r.t. the ranks
46
+ * that the given ids belong along with their new positions.
48
47
*
49
48
* @param nodes Node id tensor to be mapped to a rank in [0, world_size).
50
49
* @param part_ids Rank tensor the nodes belong to.
51
50
* @param offsets_dev Offsets to separate different node types.
52
51
* @param world_size World size, the total number of cooperating GPUs.
53
52
*
54
- * @return (sorted_nodes, original_positions, rank_offsets, rank_offsets_event),
55
- * where the first one includes sorted nodes, the second contains original
56
- * positions of the sorted nodes and the third contains the offsets of the
57
- * sorted_nodes indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]]
58
- * contains nodes that belongs to the `i`th rank. Before accessing rank_offsets
59
- * on the CPU, `rank_offsets_event.synchronize()` is required.
53
+ * @return (sorted_nodes, new_positions, rank_offsets, rank_offsets_event),
54
+ * where the first one includes sorted nodes, the second contains new positions
55
+ * of the given nodes, so that sorted_nodes[new_positions] == nodes, and the
56
+ * third contains the offsets of the sorted_nodes indicating
57
+ * sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that
58
+ * belongs to the `i`th rank. Before accessing rank_offsets on the CPU,
59
+ * `rank_offsets_event.synchronize()` is required.
60
60
*/
61
61
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>
62
62
RankSortImpl (
@@ -72,11 +72,12 @@ RankSortImpl(
72
72
* @param rank Rank of the current GPU.
73
73
* @param world_size World size, the total number of cooperating GPUs.
74
74
*
75
- * @return vector of (sorted_nodes, original_positions, rank_offsets), where the
76
- * first one includes sorted nodes, the second contains original positions of
77
- * the sorted nodes and the third contains the offsets of the sorted_nodes
78
- * indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes
79
- * that belongs to the `i`th rank.
75
+ * @return vector of (sorted_nodes, new_positions, rank_offsets), where the
76
+ * first one includes sorted nodes, the second contains new positions of the
77
+ * given nodes, so that sorted_nodes[new_positions] == nodes, and the third
78
+ * contains the offsets of the sorted_nodes indicating
79
+ * sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes that
80
+ * belongs to the `i`th rank.
80
81
*/
81
82
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort (
82
83
const std::vector<torch::Tensor>& nodes_list, int64_t rank,
0 commit comments