From 9a83941b97f08e97b2ec06a92cc3b4151600de81 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Sep 2024 07:09:07 -0700 Subject: [PATCH 01/11] add P2PCommunication IR --- csrc/dispatch.h | 3 +- csrc/multidevice/communication.cpp | 105 +++++++++++++++++++++++++++++ csrc/multidevice/communication.h | 53 +++++++++++++++ 3 files changed, 160 insertions(+), 1 deletion(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 99bdfa70565..0d1da14d529 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -111,7 +111,8 @@ class Val; f(SdpaFwdOp); \ f(SdpaBwdOp); \ f(Communication); \ - f(ForLoop); + f(ForLoop); \ + f(P2PCommunication); #define DISPATCH_FOR_ALL_KIR_EXPRS(f) \ f(Allocate); \ f(Asm); \ diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 91a7085f401..b042f5305dd 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -210,6 +210,53 @@ std::string Communication::toInlineString(int indent_size) const { return toString(indent_size); } +std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type) { + switch (type) { + case P2PCommunicationType::send: + os << "send"; + break; + case P2PCommunicationType::recv: + os << "recv"; + break; + default: + NVF_THROW("unrecognized P2PCommunicationType: ", type); + } + return os; +} + +P2PCommunication::P2PCommunication( + IrBuilderPasskey passkey, + P2PCommunicationType type, + TensorView* buffer, + Val* peer, + Val* tag) + : Expr(passkey) { + if (tag == nullptr) { + tag = passkey.ir_container_->zeroVal(); + } + + addInput(buffer); + addDataAttribute(type); + addAttribute(peer); + addAttribute(tag); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(P2PCommunication) + +std::string P2PCommunication::toString(const int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "P2PCommunication " << name() << " (" + << "type=" << type() << ", " + << " buffer=" << buffer() << ", " + << "peer=" << peer() << ", " + << "tag=" << tag() << ")\n"; + return ss.str(); +} + +std::string P2PCommunication::toInlineString(int indent_size) const { + return toString(indent_size); +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, @@ -494,4 +541,62 @@ c10::intrusive_ptr postSingleCommunication( } } +namespace { + +c10::intrusive_ptr postSend( + P2PCommunication* communication, + DeviceIdxType my_device_index, + DeviceIdxType peer, + c10d::Backend* backend, + at::Tensor buffer, + int64_t tag) { + NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); + NVF_ERROR(peer != my_device_index, "Send to self at device: ", peer); + + // Needed to match ProcessGroup API + std::vector packed_buffer = {buffer}; + return backend->send( + packed_buffer, static_cast(peer), static_cast(tag)); +} + +c10::intrusive_ptr postRecv( + P2PCommunication* communication, + DeviceIdxType my_device_index, + DeviceIdxType peer, + c10d::Backend* backend, + at::Tensor buffer, + int64_t tag) { + NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); + NVF_ERROR(peer != my_device_index, "Recv to self at device: ", peer); + + // Needed to match ProcessGroup API + std::vector packed_buffer = {buffer}; + return backend->recv( + packed_buffer, static_cast(peer), static_cast(tag)); +} + +} // namespace + +c10::intrusive_ptr postSingleCommunication( + P2PCommunication* communication, + DeviceIdxType my_device_index, + DeviceIdxType peer, + c10d::Backend* backend, + at::Tensor buffer, + int64_t tag) { + NVF_ERROR(backend != nullptr); + + switch (communication->type()) { + case P2PCommunicationType::send: + return postSend( + communication, my_device_index, peer, backend, buffer, tag); + case P2PCommunicationType::recv: + return postRecv( + communication, my_device_index, peer, backend, buffer, tag); + default: + NVF_THROW("Wrong communication type: ", communication->type()); + return nullptr; + } +} + } // namespace nvfuser diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 3ed002b4702..b2688ad12de 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -111,6 +111,51 @@ class Communication : public Expr { void validate(); }; +enum class P2PCommunicationType { send, recv }; + +std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); + +class P2PCommunication : public Expr { + public: + using Expr::Expr; + + P2PCommunication( + IrBuilderPasskey passkey, + P2PCommunicationType type, + TensorView* buffer, + Val* peer, + Val* tag = nullptr); + + P2PCommunication(const P2PCommunication& other) = delete; + P2PCommunication& operator=(const P2PCommunication& other) = delete; + P2PCommunication(P2PCommunication&& other) = delete; + P2PCommunication& operator=(P2PCommunication&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "P2PCommunication"; + } + + P2PCommunicationType type() const { + return attribute(0); + } + + TensorView* buffer() const { + return input(0)->as(); + } + + Val* peer() const { + return attributeVal(1); + } + + Val* tag() const { + return attributeVal(2); + } +}; + // The method "post" triggers the execution of the communication. This call is // non-blocking. The communication can be posted multiple times. // It is assumed that the current device_index (given by @@ -177,4 +222,12 @@ c10::intrusive_ptr postSingleCommunication( at::Tensor input_tensor, at::Tensor output_tensor); +c10::intrusive_ptr postSingleCommunication( + P2PCommunication* communication, + DeviceIdxType my_device_index, + DeviceIdxType peer, + c10d::Backend* backend, + at::Tensor buffer, + int64_t tag); + } // namespace nvfuser From f0b6509e1860cc16839e90b1d8b9b2a23d646461 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Sep 2024 10:38:11 -0700 Subject: [PATCH 02/11] patch hir::Wait to accept P2PCommunication --- csrc/host_ir/executor.cpp | 2 +- csrc/host_ir/host_ir.cpp | 9 +++++++-- csrc/host_ir/host_ir.h | 6 +++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 1f358227265..9b015932555 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -184,7 +184,7 @@ void HostIrExecutor::handle(Communication* communication) { } void HostIrExecutor::handle(Wait* wait) { - Communication* communication = wait->communication(); + Expr* communication = wait->communication(); NVF_ERROR(works_.find(communication) != works_.end(), "no wait req"); auto& work = works_.at(communication); if (work != nullptr) { diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 2f867006d1f..2c327b849e7 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -174,12 +175,16 @@ bool SetCurrentStream::sameAs(const Statement* other) const { return false; } -Wait::Wait(IrBuilderPasskey passkey, Communication* communication) - : Expr(passkey, {}, {}, {communication}) { +Wait::Wait(IrBuilderPasskey passkey, Expr* expr) + : Expr(passkey, {}, {}, {expr}) { NVF_ERROR( passkey.ir_container_->isA(), // NOLINT this, "must be registered in a HostIrContainer"); + NVF_ERROR( + (expr->isOneOf()), + expr, + "must be a Communication or a P2PCommunication"); } NVFUSER_DEFINE_CLONE_AND_CREATE(Wait) diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index 279ac98c71d..50b36ccaa1a 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -164,7 +164,7 @@ class SetCurrentStream : public Expr { class Wait : public Expr { public: using Expr::Expr; - Wait(IrBuilderPasskey passkey, Communication* communication); + Wait(IrBuilderPasskey passkey, Expr* expr); Wait(const Wait& other) = delete; Wait& operator=(const Wait& other) = delete; @@ -181,8 +181,8 @@ class Wait : public Expr { bool sameAs(const Statement* other) const override; - Communication* communication() const { - return attributes_.at(0)->as(); + Expr* communication() const { + return attributes_.at(0)->as(); } }; From 99338e9c2cc61b4ab5fee7440ea56544f532fb8d Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Sep 2024 10:39:03 -0700 Subject: [PATCH 03/11] add support for P2PCommunication in HIRExecutor --- csrc/host_ir/executor.cpp | 17 +++++++++++++++++ csrc/host_ir/executor.h | 3 ++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 9b015932555..abaf476d30a 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -183,6 +183,23 @@ void HostIrExecutor::handle(Communication* communication) { output_tensor); } +void HostIrExecutor::handle(P2PCommunication* communication) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + at::Tensor buffer = + getKnownTensorOrUndefined(communication->buffer(), expr_evaluator_); + + works_[communication] = postSingleCommunication( + communication, + communicator_->deviceId(), + expr_evaluator_.evaluate(communication->peer()).as(), + communicator_->getWorld(), + buffer, + expr_evaluator_.evaluate(communication->tag()).as()); +} + void HostIrExecutor::handle(Wait* wait) { Expr* communication = wait->communication(); NVF_ERROR(works_.find(communication) != works_.end(), "no wait req"); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 54c0e80caff..fbf6041ef7f 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -77,6 +77,7 @@ class HostIrExecutor final : public OptInDispatch { void handle(SetCurrentStream* set_current_stream) override; void handle(PostOnStream* post_ir) override; void handle(Communication* communication) override; + void handle(P2PCommunication* communication) override; void handle(Wait* wait) override; void handle(ForLoop* for_loop) override; void handle(SliceOp* slice_op) override; @@ -93,7 +94,7 @@ class HostIrExecutor final : public OptInDispatch { std::unordered_map fec_; using StreamKey = std::variant; std::unordered_map streams_; - std::unordered_map> works_; + std::unordered_map> works_; }; } // namespace hir From 173de868552cf259334c390d684005655bc2222a Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Sep 2024 10:39:26 -0700 Subject: [PATCH 04/11] add P2PCommunication unit test --- tests/cpp/test_multidevice_host_ir.cpp | 55 ++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index f1df422ce83..3a5a62440b8 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -240,6 +240,61 @@ INSTANTIATE_TEST_SUITE_P( return s; }); +class P2PCommHostIrTest : public MultiDeviceTest {}; + +TEST_F(P2PCommHostIrTest, RingPairwiseExchange) { + constexpr int64_t kTensorSize = 1024; + const int64_t communicator_size = communicator_->size(); + const int64_t my_device_index = communicator_->deviceId(); + const int64_t send_peer = (my_device_index + 1) % communicator_size; + const int64_t recv_peer = + (communicator_size + my_device_index - 1) % communicator_size; + + auto hic = std::make_unique(); + FusionGuard::setCurFusion(hic.get()); + + TensorView* send_buffer = makeContigTensor(1); + TensorView* recv_buffer = makeContigTensor(1); + + auto* send = IrBuilder::create( + P2PCommunicationType::send, + send_buffer, + IrBuilder::create(send_peer)); + + auto* recv = IrBuilder::create( + P2PCommunicationType::recv, + recv_buffer, + IrBuilder::create(recv_peer)); + + auto* wait = IrBuilder::create(recv); + + hic->addInput(send_buffer); + hic->addOutput(recv_buffer); + + if (my_device_index == 0) { + hic->pushBackTopLevelExprs(send); + hic->pushBackTopLevelExprs(recv); + } else { + hic->pushBackTopLevelExprs(recv); + hic->pushBackTopLevelExprs(send); + } + hic->pushBackTopLevelExprs(wait); + + HostIrExecutor hie(std::move(hic), communicator_); + + auto options = at::TensorOptions().device(communicator_->device()); + at::Tensor send_buffer_aten = + at::randn(kTensorSize, options) + my_device_index; + at::Tensor recv_buffer_aten = at::empty(kTensorSize, options); + + auto outputs = hie.runWithInput( + {{send_buffer, send_buffer_aten}, {recv_buffer, recv_buffer_aten}}); + + // validate the obtained results + at::Tensor ref_output = send_buffer_aten + (recv_peer - my_device_index); + EXPECT_TRUE(torch::allclose(ref_output, outputs.back())); +} + } // namespace hir } // namespace nvfuser From 75159bfc5742bd52954333614e3e9897b0ec4238 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 19 Sep 2024 10:58:28 -0700 Subject: [PATCH 05/11] fix linter --- csrc/multidevice/communication.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b042f5305dd..5c7f0bc238f 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -232,7 +232,7 @@ P2PCommunication::P2PCommunication( Val* tag) : Expr(passkey) { if (tag == nullptr) { - tag = passkey.ir_container_->zeroVal(); + tag = passkey.ir_container_->zeroVal(); // NOLINT } addInput(buffer); From 8218061f63a445f8bf3e05143b614fb5e6eea7dc Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 27 Sep 2024 02:27:46 -0700 Subject: [PATCH 06/11] allow send/recv to self --- csrc/multidevice/communication.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 5c7f0bc238f..b7ebd9cd109 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -551,7 +551,6 @@ c10::intrusive_ptr postSend( at::Tensor buffer, int64_t tag) { NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); - NVF_ERROR(peer != my_device_index, "Send to self at device: ", peer); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer}; @@ -567,7 +566,6 @@ c10::intrusive_ptr postRecv( at::Tensor buffer, int64_t tag) { NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); - NVF_ERROR(peer != my_device_index, "Recv to self at device: ", peer); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer}; From 21ed5be2398ec5a1cb8eb5ccbfab6ecd588360a6 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 27 Sep 2024 02:33:20 -0700 Subject: [PATCH 07/11] add getSize to c10d::Backend mock definition --- csrc/multidevice/c10d_mock.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/multidevice/c10d_mock.h b/csrc/multidevice/c10d_mock.h index 90a631e3c02..45b5b65d1ca 100644 --- a/csrc/multidevice/c10d_mock.h +++ b/csrc/multidevice/c10d_mock.h @@ -150,6 +150,10 @@ class Backend : public torch::CustomClassHolder { const ReduceOptions& opts = ReduceOptions()) { return c10::make_intrusive(); } + + int getSize() const { + return 0; + } }; struct TCPStoreOptions { From 7bf46a8c71be0231308584cbfe81894502bd89c7 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 2 Oct 2024 19:41:24 +0300 Subject: [PATCH 08/11] minor comments --- csrc/multidevice/communication.cpp | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b7ebd9cd109..bac15e3a641 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -247,7 +247,7 @@ std::string P2PCommunication::toString(const int indent_size) const { std::stringstream ss; indent(ss, indent_size) << "P2PCommunication " << name() << " (" << "type=" << type() << ", " - << " buffer=" << buffer() << ", " + << "buffer=" << buffer() << ", " << "peer=" << peer() << ", " << "tag=" << tag() << ")\n"; return ss.str(); diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 3a5a62440b8..82e308949fc 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -240,7 +240,7 @@ INSTANTIATE_TEST_SUITE_P( return s; }); -class P2PCommHostIrTest : public MultiDeviceTest {}; +using P2PCommHostIrTest = MultiDeviceTest; TEST_F(P2PCommHostIrTest, RingPairwiseExchange) { constexpr int64_t kTensorSize = 1024; From 86c1b195cb80f91414b5217c4b25f28712a0c4bf Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 3 Oct 2024 02:20:46 -0700 Subject: [PATCH 09/11] remove tags --- csrc/host_ir/executor.cpp | 3 +-- csrc/multidevice/communication.cpp | 32 +++++++++--------------------- csrc/multidevice/communication.h | 10 ++-------- 3 files changed, 12 insertions(+), 33 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index abaf476d30a..975a306aa3d 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -196,8 +196,7 @@ void HostIrExecutor::handle(P2PCommunication* communication) { communicator_->deviceId(), expr_evaluator_.evaluate(communication->peer()).as(), communicator_->getWorld(), - buffer, - expr_evaluator_.evaluate(communication->tag()).as()); + buffer); } void HostIrExecutor::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index bac15e3a641..91e08a2b3f1 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -228,17 +228,11 @@ P2PCommunication::P2PCommunication( IrBuilderPasskey passkey, P2PCommunicationType type, TensorView* buffer, - Val* peer, - Val* tag) + Val* peer) : Expr(passkey) { - if (tag == nullptr) { - tag = passkey.ir_container_->zeroVal(); // NOLINT - } - addInput(buffer); addDataAttribute(type); addAttribute(peer); - addAttribute(tag); } NVFUSER_DEFINE_CLONE_AND_CREATE(P2PCommunication) @@ -248,8 +242,7 @@ std::string P2PCommunication::toString(const int indent_size) const { indent(ss, indent_size) << "P2PCommunication " << name() << " (" << "type=" << type() << ", " << "buffer=" << buffer() << ", " - << "peer=" << peer() << ", " - << "tag=" << tag() << ")\n"; + << "peer=" << peer() << ")\n"; return ss.str(); } @@ -548,14 +541,12 @@ c10::intrusive_ptr postSend( DeviceIdxType my_device_index, DeviceIdxType peer, c10d::Backend* backend, - at::Tensor buffer, - int64_t tag) { + at::Tensor buffer) { NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer}; - return backend->send( - packed_buffer, static_cast(peer), static_cast(tag)); + return backend->send(packed_buffer, static_cast(peer), /*tag=*/0); } c10::intrusive_ptr postRecv( @@ -563,14 +554,12 @@ c10::intrusive_ptr postRecv( DeviceIdxType my_device_index, DeviceIdxType peer, c10d::Backend* backend, - at::Tensor buffer, - int64_t tag) { + at::Tensor buffer) { NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer}; - return backend->recv( - packed_buffer, static_cast(peer), static_cast(tag)); + return backend->recv(packed_buffer, static_cast(peer), /*tag=*/0); } } // namespace @@ -580,17 +569,14 @@ c10::intrusive_ptr postSingleCommunication( DeviceIdxType my_device_index, DeviceIdxType peer, c10d::Backend* backend, - at::Tensor buffer, - int64_t tag) { + at::Tensor buffer) { NVF_ERROR(backend != nullptr); switch (communication->type()) { case P2PCommunicationType::send: - return postSend( - communication, my_device_index, peer, backend, buffer, tag); + return postSend(communication, my_device_index, peer, backend, buffer); case P2PCommunicationType::recv: - return postRecv( - communication, my_device_index, peer, backend, buffer, tag); + return postRecv(communication, my_device_index, peer, backend, buffer); default: NVF_THROW("Wrong communication type: ", communication->type()); return nullptr; diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index b2688ad12de..5a4b72d1de0 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -123,8 +123,7 @@ class P2PCommunication : public Expr { IrBuilderPasskey passkey, P2PCommunicationType type, TensorView* buffer, - Val* peer, - Val* tag = nullptr); + Val* peer); P2PCommunication(const P2PCommunication& other) = delete; P2PCommunication& operator=(const P2PCommunication& other) = delete; @@ -150,10 +149,6 @@ class P2PCommunication : public Expr { Val* peer() const { return attributeVal(1); } - - Val* tag() const { - return attributeVal(2); - } }; // The method "post" triggers the execution of the communication. This call is @@ -227,7 +222,6 @@ c10::intrusive_ptr postSingleCommunication( DeviceIdxType my_device_index, DeviceIdxType peer, c10d::Backend* backend, - at::Tensor buffer, - int64_t tag); + at::Tensor buffer); } // namespace nvfuser From 70b07592aba4506b8b54b7e6862bd236bffdd986 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 4 Oct 2024 03:02:29 -0700 Subject: [PATCH 10/11] minor comments --- csrc/multidevice/communication.cpp | 10 +++++----- csrc/multidevice/communication.h | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 91e08a2b3f1..143cf26f441 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -212,10 +212,10 @@ std::string Communication::toInlineString(int indent_size) const { std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type) { switch (type) { - case P2PCommunicationType::send: + case P2PCommunicationType::SEND: os << "send"; break; - case P2PCommunicationType::recv: + case P2PCommunicationType::RECV: os << "recv"; break; default: @@ -555,7 +555,7 @@ c10::intrusive_ptr postRecv( DeviceIdxType peer, c10d::Backend* backend, at::Tensor buffer) { - NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer); + NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer, ", which should be strictly smaller than the world size ", backend->getSize()); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer}; @@ -573,9 +573,9 @@ c10::intrusive_ptr postSingleCommunication( NVF_ERROR(backend != nullptr); switch (communication->type()) { - case P2PCommunicationType::send: + case P2PCommunicationType::SEND: return postSend(communication, my_device_index, peer, backend, buffer); - case P2PCommunicationType::recv: + case P2PCommunicationType::RECV: return postRecv(communication, my_device_index, peer, backend, buffer); default: NVF_THROW("Wrong communication type: ", communication->type()); diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 5a4b72d1de0..45c104b36d3 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -111,7 +111,7 @@ class Communication : public Expr { void validate(); }; -enum class P2PCommunicationType { send, recv }; +enum class P2PCommunicationType { SEND, RECV }; std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type); diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 82e308949fc..e97417894ef 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -257,12 +257,12 @@ TEST_F(P2PCommHostIrTest, RingPairwiseExchange) { TensorView* recv_buffer = makeContigTensor(1); auto* send = IrBuilder::create( - P2PCommunicationType::send, + P2PCommunicationType::SEND, send_buffer, IrBuilder::create(send_peer)); auto* recv = IrBuilder::create( - P2PCommunicationType::recv, + P2PCommunicationType::RECV, recv_buffer, IrBuilder::create(recv_peer)); From 6d4ab3c5d85a0efc84aede32b140da0e654b708b Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 4 Oct 2024 03:38:10 -0700 Subject: [PATCH 11/11] lint --- csrc/multidevice/communication.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 143cf26f441..29ef6995969 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -555,7 +555,12 @@ c10::intrusive_ptr postRecv( DeviceIdxType peer, c10d::Backend* backend, at::Tensor buffer) { - NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer, ", which should be strictly smaller than the world size ", backend->getSize()); + NVF_ERROR( + peer < backend->getSize(), + "invalid peer: ", + peer, + ", which should be strictly smaller than the world size ", + backend->getSize()); // Needed to match ProcessGroup API std::vector packed_buffer = {buffer};