Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADAG]Enable NPU (hccl) communication for aDAG #47658

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ray.experimental.channel.torch_tensor_type import (
TorchTensorType,
)
from ray.air._internal.device_manager.npu import NPU_TORCH_PACKAGE_AVAILABLE


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -863,6 +864,87 @@ def test_torch_tensor_exceptions(ray_start_regular):
compiled_dag.teardown()


NPU_DEVICES = "0,1,2,3,4,5,6,7"


@ray.remote(resources={"NPU": 1})
class TorchTensorWorkerNPU:
# NOTE(zhilong): To run NPU test, we need to change
# "from ray.experimental.channel.nccl_group import _NcclGroup"
# to "from ray.experimental.channel.hccl_group import _HcclGroup"
# in "python/ray/experimental/channel/torch_tensor_nccl_channel.py"
# and also disable All GPU device check.

# TODO(zhilong): Refactor the aDAG channel so it support different
# XPUs.

def __init__(self, rank):
import torch # noqa: F401

os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES
import torch_npu

self.rank = rank
torch_npu.npu.set_device(rank)

def send(self, shape, dtype, value: int):
import torch

os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES
import torch_npu

# May need to import twice to keep the context,
# otherwise it will lose the ctx.
# Different from nccl with cupy, NPU channel relies on torch,
# so we need to keep the torch ctx.
# Create and return a tensor filled with 'value' on the current NPU
torch_npu.npu.set_device(self.rank)
tensor = torch.ones(shape, dtype=dtype) * value
return tensor.to(f"npu:{self.rank}")

def recv(self, tensor):
# Verify the tensor is on the correct device and return it as CPU tensor
tensor = tensor.cpu()
return (tensor[0].item(), tensor.shape, tensor.dtype)


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_npu_communication(ray_start_regular):
if not NPU_TORCH_PACKAGE_AVAILABLE:
pytest.skip("This test requires NPUs.")

assert (
sum(node["Resources"].get("NPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 NPUs"

# Initialize actor class with NPU support
actor_cls = TorchTensorWorkerNPU
sender = actor_cls.remote(0)
receiver = actor_cls.remote(1)

shape = (10,)
dtype = torch.float16

# Define the DAG with NPU actors
with InputNode() as inp:
dag = sender.send.bind(shape, dtype, inp)
# Can use with hccl after PR 47845 merged
dag = dag.with_type_hint(
TorchTensorType(shape, dtype, transport="hccl", _direct_return=True)
)
dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()

# Test tensor sending and receiving on NPUs
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == (i, shape, dtype)

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down
2 changes: 1 addition & 1 deletion python/ray/experimental/channel/gpu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@DeveloperAPI
class GPUCommunicator(ABC):
"""
Communicator for a group of aDAG actors on Nvidia GPU.
Communicator for a group of aDAG actors on Nvidia GPU or other XPUs.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably change the class name to a more general one if this is to support other XPUs. This is not yet used externally so backward compatibility is not an issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Next step I prefer to change it to AcceleratorCommunicator or just Communicator for all. Currently, this GPUCommunicator is also called from some top level so I just keep it now.


The aDAG execution leverages this internally to support communication
between actors in the group.
Expand Down
190 changes: 190 additions & 0 deletions python/ray/experimental/channel/hccl_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import logging
import os
from typing import Optional

import torch
import torch.distributed as dist
import torch_npu # The torch_npu for communicate

import ray
from ray.exceptions import RayChannelError
from ray.experimental.channel.gpu_communicator import (
GPUCommunicator,
TorchTensorAllocator,
)

# Set ASCEND_RT_VISIBLE_DEVICES environment variable to ensure all NPUs are visible
# This enables NPU to NPU communication across devices.
# Explaination: Since currently the worker can only see the GPU/NPU asign to
# that worker, the NPU needs to see all NPUs to enable the communication channel.
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

logger = logging.getLogger(__name__)


class _HcclGroup(GPUCommunicator):
"""
Represents an actor's HCCL communicator using NPUs.

This is the default HCCL communicator to be used in Compiled Graphs if a
custom communicator is not provided.

This class is not thread-safe.
"""

def __init__(
self,
world_size: int,
comm_id: int,
rank: int,
actor_handles: list,
cuda_stream: Optional[int],
):
# TODO(zhilong): Change cuda_stream to more general name like "stream".
"""
Initialize an HCCL communicator that can be used to communicate p2p with
other NPU actors.

This method blocks until the same call has been made on all other
actors in the group, with the same arguments for world_size and comm_id.

Args:
world_size: The number of participating actors/devices.
comm_id: A unique communicator ID.
rank: The rank of this actor. If None, then the caller is not a
participant of the HCCL group.
actor_handles: A list of actor handles, in rank order.
cuda_stream: Consistency with GPUCommunicator API. Hccl does not use cuda.
"""
self._world_size: int = world_size
self._comm_id: int = comm_id
self._rank: int = rank
self._actor_handles: list = actor_handles
self._closed: bool = False
# Initialize distributed HCCL communication if rank is provided
if rank is not None:
self._init_dist_hccl(rank, world_size)

def _init_dist_hccl(self, rank, world_size):
"""
Initialize the HCCL communication group on NPUs.

Args:
rank: The rank of the current process.
world_size: The total number of processes participating
in the communication.
"""
# Set environment variables if not already set
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
os.environ["HCCL_WHITELIST_DISABLE"] = os.environ.get(
"HCCL_WHITELIST_DISABLE", "1"
)

torch_npu.npu.set_device(rank) # Set the NPU device according to the rank
self.ctx = dist.init_process_group(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call this process_group?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aha.. This is different from process_group....The ascend torch_npu is a little different when handling the distributed while other parts are the same. https://github.com/Ascend/pytorch/blob/868b6f8e00eb0fb179fe719a81e13d8ec1860873/test/distributed/test_send_recv.py#L25

backend="hccl", world_size=world_size, rank=rank
)

def initialize(self, rank: int) -> None:
pass # No additional initialization needed for HCCL group

def get_actor_handles(self) -> list:
"""
Return the list of actor handles.

Returns:
list: Actor handles in rank order.
"""
return self._actor_handles

def get_rank(self, actor: "ray.actor.ActorHandle") -> int:
"""
Return the given actor's rank in the HCCL communicator.

Args:
actor: The actor handle to look up.

Returns:
int: The rank of the actor.
"""
actor_ids = [a._ray_actor_id for a in self._actor_handles]
try:
rank = actor_ids.index(actor._ray_actor_id)
except ValueError:
raise ValueError("Actor is not in the HCCL group.")
return rank

def get_self_rank(self) -> int:
"""
Return this actor's rank.

Returns:
int: The rank of this actor in the HCCL group.
"""
return self._rank

def get_world_size(self) -> int:
"""
Return the number of ranks in the HCCL communicator.

Returns:
int: The world size of the HCCL group.
"""
return self._world_size

def send(self, tensor: "torch.Tensor", peer_rank: int) -> None:
"""
Send a tensor to a peer using HCCL.

Args:
tensor: The tensor to be sent.
peer_rank: The rank of the peer to send the tensor to.
"""
if self._closed:
raise RayChannelError("HCCL group has been destroyed.")
logger.info(f"start to send to:{peer_rank},self._rank : {self._rank} ")
if self._closed:
raise RuntimeError("HCCL group has been destroyed.")
dist.send(tensor, dst=peer_rank)
logger.info(f"finishe send to dist {peer_rank}")

def recv(
self,
shape: tuple,
dtype: "torch.dtype",
peer_rank: int,
allocator=Optional[TorchTensorAllocator],
) -> "torch.Tensor":
"""
Receive a tensor from a peer using HCCL.

Args:
shape: The shape of the tensor to receive.
dtype: The data type of the tensor.
peer_rank: The rank of the peer to receive the tensor from.
allocator: Optional allocator to allocate memory for the tensor.

Returns:
torch.Tensor: The received tensor.
"""
if self._closed:
raise RuntimeError("HCCL group has been destroyed.")
torch_npu.npu.set_device(f"npu:{self._rank}")
tensor = torch.zeros(*shape, dtype=dtype).to(f"npu:{self._rank}")
dist.recv(tensor, src=peer_rank)
return tensor

def destroy(self) -> None:
"""
Destroy the HCCL group and clean up resources.
"""
if self._closed:
return
self._closed = True
dist.destroy_process_group()
if self._rank is not None:
logger.info(
"Destructing NCCL group on actor: "
f"{ray.get_runtime_context().current_actor}"
)