Implement DotOperandEncodingAttr::getSizePerThread with block layout parent #5863

Conversation
…as parent Signed-off-by: Anatoly Myachev <[email protected]>
lib/Dialect/TritonGPU/IR/Dialect.cpp (Outdated)

```diff
@@ -2213,10 +2213,11 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
   assert(parentLayout && "DotOperandEncodingAttr must have a parent");
   if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
     return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx());
   } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(parentLayout)) {
     return expandMatrixShapeWithBatch(ArrayRef(blocked.getSizePerThread()));
```
@lezcano I thought this method was going to be deprecated?
We plan to implement everything in terms of LinearLayouts, yes, but for now I'm happy to take this to unblock them. It'd be nice to add a test in DialectTest.cpp (`LinearEncodingTest`), making sure that this implementation agrees, at least in principle, with the LinearEncoding implementation.
Hi @lezcano, I'm trying to write a test but I get the following error. Could you tell me whether this is a mistake in how I wrote the test, or a real incompatibility between the layouts?

```
Expected equality of these values:
  distributedEncoding.getSizePerThread()
    Which is: { 1, 4, 4 }
  linearEncoding.getSizePerThread()
    Which is: { 4, 1 }
```
The changes I ran the test with:

```diff
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 797b55d43..1c7d16340 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -502,6 +502,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}),
   };
   SmallVector<triton::gpu::DistributedEncodingTrait> distributedEncodings;
+  SmallVector<triton::gpu::DistributedEncodingTrait> distributedEncodings2;

   // Create BlockedEncodingAttr and SliceEncodingAttr
   {
@@ -516,6 +517,11 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         distributedEncodings.push_back(blockedEncoding);
         distributedEncodings.push_back(
             triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding));
+        // Create an opIdx=0 and opIdx=1 encoding
+        for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+          distributedEncodings2.push_back(triton::gpu::DotOperandEncodingAttr::get(
+              &ctx, opIdx, blockedEncoding, 0));
+        }
       }
     }
   }
@@ -538,6 +544,30 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     }
   }

+  for (const auto &distributedEncoding : distributedEncodings2) {
+    for (auto shape : shapes) {
+      if (auto sliceEncoding =
+              dyn_cast<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
+        shape.erase(shape.begin() + sliceEncoding.getDim());
+      }
+
+      // Create LinearEncodingAttr from the LinearLayout
+      auto linearLayout = distributedEncoding.toLinearLayout(shape);
+      auto linearEncoding =
+          triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout);
+
+      if (auto layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(
+              distributedEncoding)) {
+        if (isa<triton::gpu::BlockedEncodingAttr>(layout.getParent())) {
+          // FIXME: This happens to be correct for SliceLayout because of the
+          // hack in SliceEncodingAttr::toLinearLayout(). We should remove the
+          // hack and the skips in getWarpsPerCTA() and getThreadsPerWarp()
+          ASSERT_EQ(distributedEncoding.getSizePerThread(),
+                    linearEncoding.getSizePerThread());
+        }
+      }
+    }
+  }
+
   for (const auto &distributedEncoding : distributedEncodings) {
     for (auto shape : shapes) {
       if (auto sliceEncoding =
```
At first sight, the LinearLayout looks more correct to me. For starters, it has the right […]. In general, look at the structure of the linear layout and see if it returns what it should, then compare with what the legacy path returns.
Signed-off-by: Anatoly Myachev <[email protected]>
Thanks for looking! Right, it was because of using […].
```cpp
// Create an opIdx=0 and opIdx=1 encoding
for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
  distributedEncodings2.push_back(triton::gpu::DotOperandEncodingAttr::get(
      &ctx, opIdx, blockedEncoding, 0));
}
```
Mind adding it to the previous vector? If it's failing some other tests, feel free to "skip them" using if statements.
will do
done
… skip failed tests Signed-off-by: Anatoly Myachev <[email protected]>
```diff
@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
   // Test that methods of DistributedEncoding return the same values
   Type eltTy = Float32Type::get(&ctx);

-  ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+  if (!is_dot_op_with_block_parent(distributedEncoding)) {
+    ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
```
Wrong order
```diff
-  ASSERT_EQ(distributedEncoding.getContigPerThread(),
-            linearEncoding.getContigPerThread());
+  if (!is_dot_op_with_block_parent(distributedEncoding)) {
+    ASSERT_EQ(distributedEncoding.getRepOrder(),
```
Wrong order
```cpp
if (!is_dot_op_with_block_parent(distributedEncoding)) {
  ASSERT_EQ(distributedEncoding.getRepOrder(),
            linearEncoding.getRepOrder());
  ASSERT_EQ(distributedEncoding.getContigPerThread(),
```
`llvm::SmallVector<unsigned int> mlir::triton::gpu::DotOperandEncodingAttr::getContigPerThread()`: Assertion `kWidth != 0 && "Do not support kWidth=0"` failed.
```diff
-  ASSERT_EQ(distributedEncoding.getThreadOrder(),
-            linearEncoding.getThreadOrder());
+  if (!is_dot_op_with_block_parent(distributedEncoding)) {
+    ASSERT_EQ(distributedEncoding.getThreadOrder(),
```
Wrong order: `{ 1, 0 }` vs `{ 0, 1 }`
```diff
-  ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
-            linearEncoding.getThreadsPerWarp());
+  if (!is_dot_op_with_block_parent(distributedEncoding)) {
+    ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
```
`LLVM ERROR: getThreadsPerWarp not implemented for DotOperandEncodingAttr`
```diff
@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
   // If we are not using CGAs, the order is meaningless
   auto useCGA =
       baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1);
-  if (useCGA) {
+  if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
```
Wrong order: `{ 1, 0 }` vs `{ 0, 1 }`. Could all the order issues be because the LinearEncodingAttr is computing the correct order for […]?
Anyway, that's a preexisting issue. Better to have skips than not testing at all! Thank you
@lezcano thank you for the review and suggestions!
For the XPU backend, the logic of the common code is slightly changed, and some Triton lit tests run into the problem of an unimplemented function.