Skip to content

Commit

Permalink
patch hir::Wait to accept P2PCommunication
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Sep 19, 2024
1 parent 3b27451 commit c48ea11
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ void HostIrExecutor::handle(Communication* communication) {
}

void HostIrExecutor::handle(Wait* wait) {
Communication* communication = wait->communication();
Expr* communication = wait->communication();
NVF_ERROR(works_.find(communication) != works_.end(), "no wait req");
auto& work = works_.at(communication);
if (work != nullptr) {
Expand Down
9 changes: 7 additions & 2 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ir/printer.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <multidevice/communication.h>
#include <ops/all_ops.h>

namespace nvfuser {
Expand Down Expand Up @@ -174,12 +175,16 @@ bool SetCurrentStream::sameAs(const Statement* other) const {
return false;
}

Wait::Wait(IrBuilderPasskey passkey, Communication* communication)
: Expr(passkey, {}, {}, {communication}) {
Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
: Expr(passkey, {}, {}, {expr}) {
NVF_ERROR(
passkey.ir_container_->isA<hir::HostIrContainer>(), // NOLINT
this,
"must be registered in a HostIrContainer");
NVF_ERROR(
(expr->isOneOf<Communication, P2PCommunication>()),
expr,
"must be a Communication or a P2PCommunication");
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)
Expand Down
6 changes: 3 additions & 3 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class SetCurrentStream : public Expr {
class Wait : public Expr {
public:
using Expr::Expr;
Wait(IrBuilderPasskey passkey, Communication* communication);
Wait(IrBuilderPasskey passkey, Expr* expr);

Wait(const Wait& other) = delete;
Wait& operator=(const Wait& other) = delete;
Expand All @@ -181,8 +181,8 @@ class Wait : public Expr {

bool sameAs(const Statement* other) const override;

Communication* communication() const {
return attributes_.at(0)->as<Communication>();
Expr* communication() const {
return attributes_.at(0)->as<Expr>();
}
};

Expand Down

0 comments on commit c48ea11

Please sign in to comment.