Implement DotOperandEncodingAttr::getSizePerThread with block layout parent (#5863)

For the XPU backend, the logic of the common code is slightly changed, and
some Triton lit tests run into this previously unimplemented function.
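
For illustration only (not part of the commit): a minimal sketch of the call that used to abort when a dot-operand layout has a blocked parent. It assumes an MLIRContext with the TritonGPU dialect loaded and a BlockedEncodingAttr built as in the unit test below; the helper function name is hypothetical.

#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical helper (not from this commit): query the per-thread tile size
// of a dot-operand layout whose parent is a blocked layout.
SmallVector<unsigned>
sizePerThreadForDotOperand(MLIRContext &ctx,
                           triton::gpu::BlockedEncodingAttr blocked,
                           unsigned opIdx) {
  // kWidth = 0, matching the unit test below.
  auto dotOp =
      triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, blocked, 0);
  // Before this commit, this call hit report_fatal_error for a blocked
  // parent; now it simply returns the parent's sizePerThread.
  return dotOp.getSizePerThread();
}

The change forwards to BlockedEncodingAttr::getSizePerThread(), mirroring how MMA parents are already handled via getSizePerThreadForOperand.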

---------

Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev authored Feb 11, 2025
1 parent cdf49bf commit 6afc767
Showing 2 changed files with 33 additions and 12 deletions.
lib/Dialect/TritonGPU/IR/Dialect.cpp (5 changes: 3 additions & 2 deletions)
@@ -2213,10 +2213,11 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
   assert(parentLayout && "DotOperandEncodingAttr must have a parent");
   if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
     return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(parentLayout)) {
+    return blocked.getSizePerThread();
   } else {
     llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-NvidiaMmaEncodingAttr parent not "
-        "supported yet");
+        "getSizePerThread not implemented for DotOperandEncodingAttr");
     return {};
   }
 }
unittest/Dialect/TritonGPU/DialectTest.cpp (40 changes: 30 additions & 10 deletions)
@@ -516,6 +516,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         distributedEncodings.push_back(blockedEncoding);
         distributedEncodings.push_back(
             triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding));
+        // Create an opIdx=0 and opIdx=1 encoding
+        for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+          distributedEncodings.push_back(
+              triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx,
+                                                       blockedEncoding, 0));
+        }
       }
     }
   }
@@ -538,6 +544,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     }
   }
 
+  auto is_dot_op_with_block_parent = [](Attribute layout) {
+    auto dot_layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout);
+    return dot_layout &&
+           isa<triton::gpu::BlockedEncodingAttr>(dot_layout.getParent());
+  };
+
   for (const auto &distributedEncoding : distributedEncodings) {
     for (auto shape : shapes) {
       if (auto sliceEncoding =
@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       // Test that methods of DistributedEncoding return the same values
       Type eltTy = Float32Type::get(&ctx);
 
-      ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+      }
       ASSERT_EQ(distributedEncoding.getTotalElemsPerThread(shape),
                 linearEncoding.getTotalElemsPerThread(shape));
       ASSERT_EQ(distributedEncoding.getElemsPerThread(shape),
                 linearEncoding.getElemsPerThread(shape));
-      ASSERT_EQ(distributedEncoding.getRepOrder(),
-                linearEncoding.getRepOrder());
-      ASSERT_EQ(distributedEncoding.getContigPerThread(),
-                linearEncoding.getContigPerThread());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getRepOrder(),
+                  linearEncoding.getRepOrder());
+        ASSERT_EQ(distributedEncoding.getContigPerThread(),
+                  linearEncoding.getContigPerThread());
+      }
       // DotOperandEncodingAttr::getWarpOrder() is not defined
       if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpOrder(),
                   linearEncoding.getWarpOrder());
       }
-      ASSERT_EQ(distributedEncoding.getThreadOrder(),
-                linearEncoding.getThreadOrder());
+      if (!is_dot_op_with_block_parent(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getThreadOrder(),
+                  linearEncoding.getThreadOrder());
+      }
       // For slice these do not equal the total number of lines / warps
       // See [Note. Divergence of methods wrt. legacy layouts]
       if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
         ASSERT_EQ(distributedEncoding.getWarpsPerCTA(),
                   linearEncoding.getWarpsPerCTA());
-        ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
-                  linearEncoding.getThreadsPerWarp());
+        if (!is_dot_op_with_block_parent(distributedEncoding)) {
+          ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
+                    linearEncoding.getThreadsPerWarp());
+        }
       }
       // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes
       // the second repetition along K as the second tile.
@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       // If we are not using CGAs, the order is meaningless
       auto useCGA =
          baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1);
-      if (useCGA) {
+      if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
        ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder());
      }
    }
