Skip to content

Commit

Permalink
add P2PCommunication unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Sep 19, 2024
1 parent f969cfa commit b923f54
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,61 @@ INSTANTIATE_TEST_SUITE_P(
return s;
});

// Test fixture for host-IR point-to-point communication tests. It adds no
// members or overrides of its own; everything comes from MultiDeviceTest.
class P2PCommHostIrTest : public MultiDeviceTest {};

// Ring exchange: each rank sends a 1-D buffer to rank (r + 1) % N and receives
// one from rank (r - 1 + N) % N, expressed as a host-IR program made of one
// send, one recv and a wait on the recv.
TEST_F(P2PCommHostIrTest, RingPairwiseExchange) {
  constexpr int64_t kNumElements = 1024;

  // Position of this process in the ring and its two neighbours.
  const int64_t world_size = communicator_->size();
  const int64_t rank = communicator_->deviceId();
  const int64_t next_rank = (rank + 1) % world_size;
  // Adding world_size before the subtraction keeps the dividend non-negative
  // for rank 0.
  const int64_t prev_rank = (world_size + rank - 1) % world_size;

  auto container = std::make_unique<HostIrContainer>();
  FusionGuard::setCurFusion(container.get());

  // Symbolic 1-D contiguous tensors standing in for the two buffers.
  TensorView* send_tv = makeContigTensor(1);
  TensorView* recv_tv = makeContigTensor(1);

  auto* send_comm = IrBuilder::create<P2PCommunication>(
      P2PCommunicationType::send, send_tv, IrBuilder::create<Val>(next_rank));
  auto* recv_comm = IrBuilder::create<P2PCommunication>(
      P2PCommunicationType::recv, recv_tv, IrBuilder::create<Val>(prev_rank));
  // Only the receive is waited on.
  auto* wait_recv = IrBuilder::create<Wait>(recv_comm);

  container->addInput(send_tv);
  container->addOutput(recv_tv);

  // Rank 0 posts its send before its recv; every other rank posts recv first.
  // NOTE(review): presumably this breaks the cyclic dependency of the ring so
  // that not all ranks start with the same operation — confirm.
  const bool is_first_rank = (rank == 0);
  auto* first_comm = is_first_rank ? send_comm : recv_comm;
  auto* second_comm = is_first_rank ? recv_comm : send_comm;
  container->pushBackTopLevelExprs(first_comm);
  container->pushBackTopLevelExprs(second_comm);
  container->pushBackTopLevelExprs(wait_recv);

  HostIrExecutor executor(std::move(container), communicator_);

  // Concrete buffers: random payload shifted by the rank, plus an
  // uninitialized receive buffer of the same size.
  const auto tensor_options =
      at::TensorOptions().device(communicator_->device());
  at::Tensor send_data = at::randn(kNumElements, tensor_options) + rank;
  at::Tensor recv_data = at::empty(kNumElements, tensor_options);

  auto results =
      executor.runWithInput({{send_tv, send_data}, {recv_tv, recv_data}});

  // The expected payload is rebuilt from the local send buffer by swapping
  // this rank's offset for the sender's offset.
  // NOTE(review): this assumes every rank draws identical randn values (same
  // RNG seed on all processes) — confirm against the fixture setup.
  const int64_t rank_offset = prev_rank - rank;
  at::Tensor expected = send_data + rank_offset;
  EXPECT_TRUE(torch::allclose(expected, results.back()));
}

} // namespace hir

} // namespace nvfuser

0 comments on commit b923f54

Please sign in to comment.