
Host irs: add p2p comms #2970

Merged: 14 commits, Oct 4, 2024
3 changes: 2 additions & 1 deletion csrc/dispatch.h
@@ -111,7 +111,8 @@ class Val;
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(Communication); \
-  f(ForLoop);
+  f(ForLoop); \
+  f(P2PCommunication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
f(Allocate); \
f(Asm); \
18 changes: 17 additions & 1 deletion csrc/host_ir/executor.cpp
@@ -191,8 +191,24 @@ void HostIrExecutor::handle(Communication* communication) {
output_tensor);
}

void HostIrExecutor::handle(P2PCommunication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

at::Tensor buffer =
getKnownTensorOrUndefined(communication->buffer(), expr_evaluator_);

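// The peer is an IR Val, so the concrete peer rank is evaluated only here,
// at execution time.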
works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
expr_evaluator_.evaluate(communication->peer()).as<int64_t>(),
communicator_->getWorld(),
buffer);
}

void HostIrExecutor::handle(Wait* wait) {
-  Communication* communication = wait->communication();
+  Expr* communication = wait->communication();
NVF_ERROR(works_.find(communication) != works_.end(), "no wait req");
auto& work = works_.at(communication);
if (work != nullptr) {
1 change: 1 addition & 0 deletions csrc/host_ir/executor.h
@@ -78,6 +78,7 @@ class HostIrExecutor final : public OptInDispatch {
void handle(Synchronize* synchronize) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
void handle(P2PCommunication* communication) override;
void handle(Wait* wait) override;
void handle(ForLoop* for_loop) override;
void handle(SliceOp* slice_op) override;
9 changes: 7 additions & 2 deletions csrc/host_ir/host_ir.cpp
@@ -14,6 +14,7 @@
#include <ir/printer.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <multidevice/communication.h>
#include <ops/all_ops.h>

namespace nvfuser {
@@ -178,13 +179,17 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

-Wait::Wait(IrBuilderPasskey passkey, Communication* communication)
-    : Expr(passkey, {}, {}, {communication}) {
+Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
+    : Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<HostIrContainer>(),
this,
"must be registered in a HostIrContainer");
NVF_ERROR(
(expr->isOneOf<Communication, P2PCommunication>()),
expr,
"must be a Communication or a P2PCommunication");
}

Collaborator:
Not for this PR, but would you prefer having Communication, CollectiveCommunication (a subclass), and P2PCommunication? This seems more structured but adds one level of inheritance hierarchy, potentially making the code less trackable.

Collaborator (Author):
That was a thought. But I have the feeling that, in nvFuser in general, the infrastructure doesn't behave well with hierarchical inheritance in the IR. For example, I fear it would break all the dispatch functions.

Anyway, I am open to this idea, even though I cannot think of a concrete benefit right now.
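To make the suggestion concrete, here is a minimal, hypothetical sketch of the hierarchy discussed above. Neither the CollectiveCommunication middle layer nor this stand-in Expr exists in the PR; the sketch only illustrates the trade-off named in the thread.

#include <memory>

// Hypothetical sketch only -- not part of this diff.
struct Expr {
  virtual ~Expr() = default;
};

// Common base for nodes that post work on a c10d backend.
struct Communication : Expr {};

// Collectives (allgather, broadcast, reduce, ...): today's Communication.
struct CollectiveCommunication : Communication {};

// Pairwise send/recv: the node added by this PR.
struct P2PCommunication : Communication {};

int main() {
  std::unique_ptr<Expr> e = std::make_unique<P2PCommunication>();
  // nvFuser's dispatch macros enumerate concrete node types (see
  // csrc/dispatch.h above); with a middle layer, a visitor must choose
  // between matching the leaf and the Communication base, which is the
  // trackability concern raised in the reply.
  return dynamic_cast<P2PCommunication*>(e.get()) != nullptr ? 0 : 1;
}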

NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)
6 changes: 3 additions & 3 deletions csrc/host_ir/host_ir.h
@@ -164,7 +164,7 @@ class SetCurrentStream : public Expr {
class Wait : public Expr {
public:
using Expr::Expr;
-  Wait(IrBuilderPasskey passkey, Communication* communication);
+  Wait(IrBuilderPasskey passkey, Expr* expr);

Wait(const Wait& other) = delete;
Wait& operator=(const Wait& other) = delete;
@@ -181,8 +181,8 @@

bool sameAs(const Statement* other) const override;

-  Communication* communication() const {
-    return attributes_.at(0)->as<Communication>();
+  Expr* communication() const {
+    return attributes_.at(0)->as<Expr>();
}
};

4 changes: 4 additions & 0 deletions csrc/multidevice/c10d_mock.h
@@ -156,6 +156,10 @@ class Backend : public torch::CustomClassHolder {
const ReduceOptions& opts = ReduceOptions()) {
return c10::make_intrusive<Work>();
}

int getSize() const {
return 0;
}
};

struct TCPStoreOptions {
94 changes: 94 additions & 0 deletions csrc/multidevice/communication.cpp
@@ -210,6 +210,46 @@ std::string Communication::toInlineString(int indent_size) const {
return toString(indent_size);
}

std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type) {
switch (type) {
case P2PCommunicationType::SEND:
os << "send";
break;
case P2PCommunicationType::RECV:
os << "recv";
break;
default:
NVF_THROW("unrecognized P2PCommunicationType: ", type);
}
return os;
}

P2PCommunication::P2PCommunication(
IrBuilderPasskey passkey,
P2PCommunicationType type,
TensorView* buffer,
Val* peer)
: Expr(passkey) {
addInput(buffer);
addDataAttribute(type);
addAttribute(peer);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(P2PCommunication)

std::string P2PCommunication::toString(const int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "P2PCommunication " << name() << " ("
<< "type=" << type() << ", "
<< "buffer=" << buffer() << ", "
<< "peer=" << peer() << ")\n";
return ss.str();
}

std::string P2PCommunication::toInlineString(int indent_size) const {
return toString(indent_size);
}

namespace {
c10::intrusive_ptr<c10d::Work> postBroadcast(
Communication* communication,
@@ -494,4 +534,58 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
}
}

namespace {

c10::intrusive_ptr<c10d::Work> postSend(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
c10d::Backend* backend,
at::Tensor buffer) {
NVF_ERROR(peer < backend->getSize(), "invalid peer: ", peer);

// Needed to match ProcessGroup API
std::vector<at::Tensor> packed_buffer = {buffer};
return backend->send(packed_buffer, static_cast<int>(peer), /*tag=*/0);
}

c10::intrusive_ptr<c10d::Work> postRecv(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
c10d::Backend* backend,
at::Tensor buffer) {
NVF_ERROR(
peer < backend->getSize(),
"invalid peer: ",
peer,
", which should be strictly smaller than the world size ",
backend->getSize());

// Needed to match ProcessGroup API
std::vector<at::Tensor> packed_buffer = {buffer};
return backend->recv(packed_buffer, static_cast<int>(peer), /*tag=*/0);
}

} // namespace

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
c10d::Backend* backend,
at::Tensor buffer) {
NVF_ERROR(backend != nullptr);

switch (communication->type()) {
case P2PCommunicationType::SEND:
return postSend(communication, my_device_index, peer, backend, buffer);
case P2PCommunicationType::RECV:
return postRecv(communication, my_device_index, peer, backend, buffer);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
}
}

} // namespace nvfuser
47 changes: 47 additions & 0 deletions csrc/multidevice/communication.h
@@ -111,6 +111,46 @@ class Communication : public Expr {
void validate();
};

enum class P2PCommunicationType { SEND, RECV };

std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type);

class P2PCommunication : public Expr {
public:
using Expr::Expr;

P2PCommunication(
IrBuilderPasskey passkey,
P2PCommunicationType type,
TensorView* buffer,
Val* peer);

P2PCommunication(const P2PCommunication& other) = delete;
P2PCommunication& operator=(const P2PCommunication& other) = delete;
P2PCommunication(P2PCommunication&& other) = delete;
P2PCommunication& operator=(P2PCommunication&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "P2PCommunication";
}

P2PCommunicationType type() const {
return attribute<P2PCommunicationType>(0);
}

TensorView* buffer() const {
return input(0)->as<TensorView>();
}

Val* peer() const {
return attributeVal(1);
}
};

// The method "post" triggers the execution of the communication. This call is
// non-blocking. The communication can be posted multiple times.
// It is assumed that the current device_index (given by
@@ -177,4 +217,11 @@
at::Tensor input_tensor,
at::Tensor output_tensor);

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
c10d::Backend* backend,
at::Tensor buffer);

} // namespace nvfuser
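For reference, a minimal usage sketch of the new overload above, mirroring what HostIrExecutor::handle(P2PCommunication*) does. The helper name postAndWaitSend and its arguments are assumptions for illustration, not part of this diff:

// Hedged sketch: assumes a valid, available Communicator and a SEND-type
// P2PCommunication whose buffer already lives on this rank's GPU.
void postAndWaitSend(
    Communicator* communicator,
    P2PCommunication* p2p_send,
    at::Tensor buffer) {
  const DeviceIdxType my_rank = communicator->deviceId();
  const DeviceIdxType peer = (my_rank + 1) % communicator->size();
  c10::intrusive_ptr<c10d::Work> work = postSingleCommunication(
      p2p_send, my_rank, peer, communicator->getWorld(), buffer);
  work->wait();  // posting is non-blocking; block until the send completes
}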
55 changes: 55 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -240,6 +240,61 @@ INSTANTIATE_TEST_SUITE_P(
return s;
});

using P2PCommHostIrTest = MultiDeviceTest;

TEST_F(P2PCommHostIrTest, RingPairwiseExchange) {
constexpr int64_t kTensorSize = 1024;
const int64_t communicator_size = communicator_->size();
const int64_t my_device_index = communicator_->deviceId();
const int64_t send_peer = (my_device_index + 1) % communicator_size;
const int64_t recv_peer =
(communicator_size + my_device_index - 1) % communicator_size;

auto hic = std::make_unique<HostIrContainer>();
FusionGuard::setCurFusion(hic.get());

TensorView* send_buffer = makeContigTensor(1);
TensorView* recv_buffer = makeContigTensor(1);

auto* send = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::SEND,
send_buffer,
IrBuilder::create<Val>(send_peer));

auto* recv = IrBuilder::create<P2PCommunication>(
P2PCommunicationType::RECV,
recv_buffer,
IrBuilder::create<Val>(recv_peer));

auto* wait = IrBuilder::create<Wait>(recv);

hic->addInput(send_buffer);
hic->addOutput(recv_buffer);

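// Assumed rationale for the ordering below: every rank sends to its right
// neighbor and receives from its left one, so rank 0 posts its send first
// to break the cycle while all other ranks post their recv first.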
if (my_device_index == 0) {
hic->pushBackTopLevelExprs(send);
hic->pushBackTopLevelExprs(recv);
} else {
hic->pushBackTopLevelExprs(recv);
hic->pushBackTopLevelExprs(send);
}
hic->pushBackTopLevelExprs(wait);

HostIrExecutor hie(std::move(hic), communicator_);

auto options = at::TensorOptions().device(communicator_->device());
at::Tensor send_buffer_aten =
at::randn(kTensorSize, options) + my_device_index;
at::Tensor recv_buffer_aten = at::empty(kTensorSize, options);

auto outputs = hie.runWithInput(
{{send_buffer, send_buffer_aten}, {recv_buffer, recv_buffer_aten}});

// validate the result: all ranks draw the same random base values (shared
// RNG seed), so the received tensor should equal the local send buffer
// shifted by (recv_peer - my_device_index)
at::Tensor ref_output = send_buffer_aten + (recv_peer - my_device_index);
EXPECT_TRUE(torch::allclose(ref_output, outputs.back()));
}

} // namespace hir

} // namespace nvfuser