Skip to content

Commit

Permalink
add support for P2PCommunication in HIRExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Sep 19, 2024
1 parent c48ea11 commit f969cfa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
17 changes: 17 additions & 0 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,23 @@ void HostIrExecutor::handle(Communication* communication) {
output_tensor);
}

void HostIrExecutor::handle(P2PCommunication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

at::Tensor buffer =
getKnownTensorOrUndefined(communication->buffer(), expr_evaluator_);

works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
expr_evaluator_.evaluate(communication->peer()).as<int64_t>(),
communicator_->getWorld(),
buffer,
expr_evaluator_.evaluate(communication->tag()).as<int64_t>());
}

void HostIrExecutor::handle(Wait* wait) {
Expr* communication = wait->communication();
NVF_ERROR(works_.find(communication) != works_.end(), "no wait req");
Expand Down
3 changes: 2 additions & 1 deletion csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class HostIrExecutor final : public OptInDispatch {
void handle(SetCurrentStream* set_current_stream) override;
void handle(PostOnStream* post_ir) override;
void handle(Communication* communication) override;
void handle(P2PCommunication* communication) override;
void handle(Wait* wait) override;
void handle(ForLoop* for_loop) override;
void handle(SliceOp* slice_op) override;
Expand All @@ -93,7 +94,7 @@ class HostIrExecutor final : public OptInDispatch {
std::unordered_map<HostUnit*, FusionExecutorCache> fec_;
using StreamKey = std::variant<int64_t, Stream*>;
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Communication*, c10::intrusive_ptr<c10d::Work>> works_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
};

} // namespace hir
Expand Down

0 comments on commit f969cfa

Please sign in to comment.