chore: Fix type hint on collective output node
Signed-off-by: Weixin Deng <[email protected]>
dengwxn committed Sep 19, 2024
1 parent 0b92c8a commit 5601fa4
Showing 3 changed files with 62 additions and 57 deletions.
1 change: 0 additions & 1 deletion python/ray/dag/collective_node.py
@@ -17,7 +17,6 @@
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.torch_tensor_nccl_channel import _init_nccl_group
from ray.experimental.channel.torch_tensor_type import (
ChannelOutputType,
GPUCommunicator,
TorchTensorType,
)
74 changes: 36 additions & 38 deletions python/ray/dag/compiled_dag_node.py
@@ -833,39 +833,39 @@ def _preprocess(self) -> None:

self.actor_task_count[actor_handle._actor_id] += 1

if not isinstance(dag_node, CollectiveOutputNode):
# Add all writers to the NCCL group for send and recv methods.
if dag_node.type_hint.requires_nccl():
nccl_actors.add(actor_handle)
custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
mixed_nccl_group_error_message = (
"Accelerated DAGs do not support mixed usage of "
"type hints of default NCCL group "
'(i.e., TorchTensor(transport="nccl"))'
"and custom NCCL group "
"(i.e., TorchTensor(transport=nccl_group)). "
"Please check all the TorchTensor type hints and "
"make sure only one type of NCCL transport is specified."
)
if custom_nccl_group is None:
if self._custom_nccl_group is not None:
raise ValueError(mixed_nccl_group_error_message)
self._use_default_nccl_group = True
else:
if self._use_default_nccl_group:
raise ValueError(mixed_nccl_group_error_message)
if self._custom_nccl_group is not None:
if self._custom_nccl_group != custom_nccl_group:
raise ValueError(
"Accelerated DAGs currently only support "
"a single custom NCCL group, but multiple "
"have been specified. Check all the "
"TorchTensor(transport=nccl_group) type hints "
"to make sure only one NCCL group is used."
)
self._custom_nccl_group = custom_nccl_group
elif isinstance(dag_node, CollectiveOutputNode):
# Collect all collective groups.
# Collect actors for NCCL P2P methods.
if dag_node.type_hint.requires_nccl():
nccl_actors.add(actor_handle)
custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
mixed_nccl_group_error_message = (
"Accelerated DAGs do not support mixed usage of "
"type hints of default NCCL group "
'(i.e., TorchTensor(transport="nccl"))'
"and custom NCCL group "
"(i.e., TorchTensor(transport=nccl_group)). "
"Please check all the TorchTensor type hints and "
"make sure only one type of NCCL transport is specified."
)
if custom_nccl_group is None:
if self._custom_nccl_group is not None:
raise ValueError(mixed_nccl_group_error_message)
self._use_default_nccl_group = True
else:
if self._use_default_nccl_group:
raise ValueError(mixed_nccl_group_error_message)
if self._custom_nccl_group is not None:
if self._custom_nccl_group != custom_nccl_group:
raise ValueError(
"Accelerated DAGs currently only support "
"a single custom NCCL group, but multiple "
"have been specified. Check all the "
"TorchTensor(transport=nccl_group) type hints "
"to make sure only one NCCL group is used."
)
self._custom_nccl_group = custom_nccl_group

# Collect collective groups for NCCL collective methods.
if isinstance(dag_node, CollectiveOutputNode):
nccl_collective_groups.add(dag_node.collective_group)
elif isinstance(dag_node, InputNode):
if dag_node.type_hint.requires_nccl():
@@ -1074,11 +1074,9 @@ def _get_or_compile(
visited.add(cur_idx)

task = self.idx_to_task[cur_idx]
if not isinstance(task.dag_node, CollectiveOutputNode):
# The NCCL group is already initialized for CollectiveOutputNode.
type_hint = task.dag_node.type_hint
if type_hint.requires_nccl():
type_hint.set_nccl_group_id(self._nccl_group_id)
type_hint = task.dag_node.type_hint
if type_hint.requires_nccl():
type_hint.set_nccl_group_id(self._nccl_group_id)

if (
isinstance(task.dag_node, ClassMethodNode)
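Note: the validation that this diff moves out of the `not isinstance(dag_node, CollectiveOutputNode)` branch rejects DAGs that mix the default NCCL transport (`TorchTensor(transport="nccl")`) with a custom NCCL group, and allows at most one custom group overall. A minimal, self-contained sketch of that check — hypothetical helper name, not part of this commit or the Ray API — for readers following the diff:

```python
from typing import Any, List, Optional


def validate_nccl_transports(custom_groups: List[Optional[Any]]) -> Optional[Any]:
    """Reject mixed default/custom NCCL transports, mirroring the check above.

    Each entry is the custom NCCL group attached to a TorchTensor type hint,
    or None when the default transport (transport="nccl") was requested.
    """
    use_default_nccl_group = False
    chosen_custom_group: Optional[Any] = None
    for group in custom_groups:
        if group is None:
            # Default transport requested; incompatible with any custom group.
            if chosen_custom_group is not None:
                raise ValueError("mixed default and custom NCCL groups")
            use_default_nccl_group = True
        else:
            if use_default_nccl_group:
                raise ValueError("mixed default and custom NCCL groups")
            if chosen_custom_group is not None and chosen_custom_group != group:
                raise ValueError("only a single custom NCCL group is supported")
            chosen_custom_group = group
    return chosen_custom_group


# All hints share one custom group: accepted.
shared_group = object()
assert validate_nccl_transports([shared_group, shared_group]) is shared_group
```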
44 changes: 26 additions & 18 deletions python/ray/dag/dag_node_operation.py
@@ -1,6 +1,6 @@
from functools import total_ordering
from enum import Enum
from typing import Set, Tuple, List, Dict
from typing import Set, Tuple, List, Dict, Optional
import ray
import heapq
from collections import defaultdict
@@ -135,7 +135,8 @@ def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode
def _select_next_nodes(
actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]],
graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]],
):
) -> Optional[List[_DAGOperationGraphNode]]:
# [TODO] Update comments.
"""
This function selects the next nodes for topological sort to generate execution
schedule. If there are multiple candidate _DAGOperationGraphNodes, select the node
@@ -175,32 +176,33 @@ def _select_next_nodes(
execution schedules.
"""
top_priority_node = None
next_nodes: List[_DAGOperationGraphNode] = []
for _, candidates in actor_to_candidates.items():
if len(candidates) == 0:
continue
if top_priority_node is None or candidates[0] < top_priority_node:
top_priority_node = candidates[0]
assert top_priority_node is not None
next_nodes.append(
if top_priority_node is None:
return None

next_nodes: List[_DAGOperationGraphNode] = [
heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id])
)
]

if not (
top_priority_node.operation.type == _DAGNodeOperationType.WRITE
and top_priority_node.requires_nccl
):
assert len(next_nodes) == 1
return next_nodes

# An NCCL write node is picked. NCCL is a blocking operation, so we need to pick all
# the corresponding NCCL read nodes to avoid a deadlock.
for downstream_node_metadata in top_priority_node.out_edges:
task_idx, op_type = downstream_node_metadata[0], downstream_node_metadata[1]
downstream_node = graph[task_idx][op_type]
assert downstream_node.operation.type == _DAGNodeOperationType.READ
next_nodes.append(downstream_node)
assert len(next_nodes) == 1 + len(top_priority_node.out_edges)
else:
# An NCCL write node is picked. NCCL is a blocking operation, so we need to
# pick all the corresponding NCCL read nodes to avoid a deadlock.
for downstream_node_metadata in top_priority_node.out_edges:
task_idx, op_type = downstream_node_metadata[0], downstream_node_metadata[1]
downstream_node = graph[task_idx][op_type]
assert downstream_node.operation.type == _DAGNodeOperationType.READ
next_nodes.append(downstream_node)
assert len(next_nodes) == 1 + len(top_priority_node.out_edges)

return next_nodes
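
Note: read in isolation, the updated `_select_next_nodes` does two things — pop the globally highest-priority candidate across all per-actor heaps, and, when that candidate is a blocking NCCL write, pull its downstream NCCL reads into the same step so the schedule cannot deadlock. A simplified, self-contained sketch with hypothetical stand-in types (not the Ray classes in this diff):

```python
import heapq
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass(order=True)
class OpNode:
    # Hypothetical stand-in for _DAGOperationGraphNode: lower priority runs first.
    priority: int
    name: str = field(compare=False)
    is_nccl_write: bool = field(default=False, compare=False)
    # (task_idx, op_type) keys of downstream nodes, as in the real graph.
    out_edges: List[Tuple[int, str]] = field(default_factory=list, compare=False)


def select_next_nodes(
    actor_to_candidates: Dict[str, List[OpNode]],
    graph: Dict[int, Dict[str, OpNode]],
) -> Optional[List[OpNode]]:
    # Pick the globally highest-priority candidate across all actors.
    top: Optional[OpNode] = None
    top_actor: Optional[str] = None
    for actor, candidates in actor_to_candidates.items():
        if candidates and (top is None or candidates[0] < top):
            top, top_actor = candidates[0], actor
    if top is None:
        return None  # No candidates left anywhere.

    next_nodes = [heapq.heappop(actor_to_candidates[top_actor])]
    if top.is_nccl_write:
        # NCCL is blocking, so schedule the matching reads in the same step.
        next_nodes.extend(graph[idx][op] for idx, op in top.out_edges)
    return next_nodes


# Usage: one actor with a pending NCCL write whose read lives on another task.
write = OpNode(0, "w", is_nccl_write=True, out_edges=[(1, "READ")])
read = OpNode(5, "r")
print(select_next_nodes({"actor_a": [write]}, {1: {"READ": read}}))  # [write, read]
```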


@@ -343,28 +345,34 @@ def _generate_actor_to_execution_schedule(

visited_nodes = set()

# [TODO] Update comments.
# Use topological sort algorithm to generate the execution schedule. Each iteration
# pops a candidate node from `actor_to_candidates` and each DAG node consists of
# three operations: READ, COMPUTE, and WRITE.
for _ in range(len(graph) * 3):
while True:
# The function `_select_next_nodes` will pop a candidate node from
# `actor_to_candidates` and return a list of nodes that can be executed
# in the next step. If multiple nodes are returned, only the NCCL write
# node is popped in this iteration.
nodes = _select_next_nodes(actor_to_candidates, graph)
if nodes is None:
break
for node in nodes:
if node in visited_nodes:
continue
actor_to_execution_schedule[node.actor_handle].append(node.operation)
visited_nodes.add(node)
for node in nodes:
for out_node_task_idx, out_node_type in node.out_edges:
out_node = graph[out_node_task_idx][out_node_type]
out_node.in_edges.remove((node.task_idx, node.operation.type))
if out_node.in_degree == 0:
if out_node.in_degree == 0 and out_node not in visited_nodes:
heapq.heappush(
actor_to_candidates[out_node.actor_handle._actor_id],
out_node,
)
for node in nodes:
assert node.in_degree == 0, f"Expected {node} to have in degree 0"
for _, candidates in actor_to_candidates.items():
assert len(candidates) == 0
return actor_to_execution_schedule
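
Note: the rewritten loop is a heap-driven, Kahn-style topological sort that now terminates when `_select_next_nodes` reports no remaining candidates (returning `None`) instead of assuming exactly `len(graph) * 3` iterations. A toy, self-contained illustration of that `while True` / break-on-exhaustion structure, using plain strings instead of operation nodes (hypothetical data, not part of this commit):

```python
import heapq
from typing import Dict, List

# A toy dependency graph: node -> downstream nodes.
out_edges: Dict[str, List[str]] = {"read": ["compute"], "compute": ["write"], "write": []}
in_degree: Dict[str, int] = {"read": 0, "compute": 1, "write": 1}

candidates: List[str] = [n for n, d in in_degree.items() if d == 0]
heapq.heapify(candidates)
schedule: List[str] = []

# Loop until the candidate heap is exhausted, mirroring the `while True` /
# `if nodes is None: break` structure that replaced the fixed-count loop.
while True:
    if not candidates:
        break
    node = heapq.heappop(candidates)
    schedule.append(node)
    for downstream in out_edges[node]:
        # Removing the satisfied in-edge may unblock the downstream node.
        in_degree[downstream] -= 1
        if in_degree[downstream] == 0:
            heapq.heappush(candidates, downstream)

assert schedule == ["read", "compute", "write"]
```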
