
Implement DotOperandEncodingAttr::getSizePerThread with block layout parent #5863

Merged: 3 commits merged into triton-lang:main from get-size-per-thread on Feb 11, 2025

Conversation

anmyachev
Contributor

For the XPU backend, the logic of the common code is slightly different, and some Triton lit tests run into an unimplemented-function error for this method.

@@ -2213,10 +2213,11 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
  if (auto parentMmaLayout = mlir::dyn_cast<MmaEncodingTrait>(parentLayout)) {
    return parentMmaLayout.getSizePerThreadForOperand(getKWidth(), getOpIdx());
  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(parentLayout)) {
    return expandMatrixShapeWithBatch(ArrayRef(blocked.getSizePerThread()));
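
For concreteness, a hypothetical snippet (not part of the PR; the layout parameters and the exact BlockedEncodingAttr::get signature are my assumptions for illustration) showing how the new blocked-parent branch would be exercised instead of hitting an unimplemented-function error:

// Hypothetical illustration: a dot-operand encoding whose parent is a
// blocked layout, queried through the new branch.
MLIRContext ctx;
ctx.getOrLoadDialect<triton::gpu::TritonGPUDialect>();
auto ctaLayout = triton::gpu::CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0});
auto blocked = triton::gpu::BlockedEncodingAttr::get(
    &ctx, /*sizePerThread=*/{4, 4}, /*threadsPerWarp=*/{8, 4},
    /*warpsPerCTA=*/{1, 1}, /*order=*/{1, 0}, ctaLayout);
auto dotOp = triton::gpu::DotOperandEncodingAttr::get(&ctx, /*opIdx=*/0,
                                                      blocked, /*kWidth=*/0);
// Previously this call hit the unimplemented path; with the change the result
// is derived from the parent blocked layout's sizePerThread.
auto sizePerThread = dotOp.getSizePerThread();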
Contributor

@lezcano I thought this method was going to be deprecated?

Contributor

@lezcano lezcano left a comment

We plan to implement everything in terms of LinearLayouts, yes, but for now I'm happy to take this to unblock them. It'd be nice to add a test in DialectTest.cpp LinearEncodingTest, making sure that this implementation agrees, at least in principle, with the LinearEncoding implementation.

@anmyachev
Contributor Author

We plan to implement everything in terms of LinearLayouts, yes, but for now I'm happy to take this to unblock them. It'd be nice to add a test in DialectTest.cpp LinearEncodingTest, making sure that this implementation agrees, at least in principle, with the LinearEncoding implementation.

Hi @lezcano,

I'm trying to write a test, but I get the following error. Could you tell me whether this is a mistake in how I wrote the test, or an existing layout incompatibility?

Expected equality of these values:
  distributedEncoding.getSizePerThread()
    Which is: { 1, 4, 4 }
  linearEncoding.getSizePerThread()
    Which is: { 4, 1 }

The changes I ran the test with:

Patch
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 797b55d43..1c7d16340 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -502,6 +502,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
       triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}),
   };
   SmallVector<triton::gpu::DistributedEncodingTrait> distributedEncodings;
+  SmallVector<triton::gpu::DistributedEncodingTrait> distributedEncodings2;
 
   // Create BlockedEncodingAttr and SliceEncodingAttr
   {
@@ -516,6 +517,11 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
         distributedEncodings.push_back(blockedEncoding);
         distributedEncodings.push_back(
             triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding));
+        // Create an opIdx=0 and opIdx=1 encoding
+        for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
+          distributedEncodings2.push_back(
+            triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, blockedEncoding, 0));
+        }
       }
     }
   }
@@ -538,6 +544,30 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     }
   }
 
+  for (const auto &distributedEncoding : distributedEncodings2) {
+    for (auto shape : shapes) {
+      if (auto sliceEncoding =
+              dyn_cast<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
+        shape.erase(shape.begin() + sliceEncoding.getDim());
+      }
+
+      // Create LinearEncodingAttr from the LinearLayout
+      auto linearLayout = distributedEncoding.toLinearLayout(shape);
+      auto linearEncoding =
+          triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout);
+
+        if (auto layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
+          if (isa<triton::gpu::BlockedEncodingAttr>(layout.getParent())) {
+              // FIXME: This happens to be correct for SliceLayout because of the hack
+              // in SliceEncodingAttr::toLinearLayout(). We should remove the hack
+              // and the skips in the getWarpsPerCTA() and getThreadsPerWarp()
+              ASSERT_EQ(distributedEncoding.getSizePerThread(),
+                        linearEncoding.getSizePerThread());
+          }
+        }
+    }
+  }
+
   for (const auto &distributedEncoding : distributedEncodings) {
     for (auto shape : shapes) {
       if (auto sliceEncoding =

@lezcano
Contributor

lezcano commented Feb 10, 2025

At first sight, the LinearLayout result looks more correct to me. For starters, it has the right rank, right? It's difficult to tell without knowing which blocked layout was used.

In general, look at the structure of the linear layout and see if it returns what it should, then compare with what the legacy path returns.
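
One way to do that inspection, building on the unit-test patch above (this assumes LinearLayout exposes a textual dump via toString(); the other calls are already used in the patch):

// Dump the linear layout's bases (register/lane/warp/block) for a given shape
// and print the legacy getter's answer next to it for a side-by-side check.
auto linearLayout = distributedEncoding.toLinearLayout(shape);
llvm::errs() << linearLayout.toString() << "\n";
llvm::errs() << "legacy getSizePerThread:";
for (unsigned d : distributedEncoding.getSizePerThread())
  llvm::errs() << " " << d;
llvm::errs() << "\n";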

@anmyachev
Contributor Author

For starters, it has the right rank, right?

Thanks for looking! Right, it was because of using expandMatrixShapeWithBatch (removed it). Now the difference is: { 4, 4 } vs { 4, 1 }.
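
For reference, a minimal sketch of the behavior I'd expect from expandMatrixShapeWithBatch given the { 4, 4 } and { 1, 4, 4 } shapes above; this is a reconstruction inferred from the observed output, not the actual Triton helper:

#include <cassert>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Assumed behavior: prepend a unit batch dimension to a rank-2 matrix shape
// and pass rank-3 shapes through unchanged, so {4, 4} becomes {1, 4, 4}.
llvm::SmallVector<unsigned> expandWithBatchSketch(llvm::ArrayRef<unsigned> s) {
  if (s.size() == 3)
    return llvm::SmallVector<unsigned>(s.begin(), s.end());
  assert(s.size() == 2 && "expected a rank-2 matrix shape");
  return {1, s[0], s[1]};
}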

Comment on lines 520 to 525

        // Create an opIdx=0 and opIdx=1 encoding
        for (unsigned opIdx = 0; opIdx < 2; ++opIdx) {
          distributedEncodings2.push_back(
              triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx,
                                                       blockedEncoding, 0));
        }
Contributor

Mind adding it to the previous vector? If it's failing some other tests, feel free to "skip them" using if statements.
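
A minimal sketch of such a skip predicate; the diffs below refer to a helper named is_dot_op_with_block_parent, but this particular definition is my reconstruction, not the code that landed:

// Returns true for dot-operand encodings whose parent is a blocked layout,
// so the legacy getters that don't yet agree can be skipped in the test.
static bool is_dot_op_with_block_parent(mlir::Attribute layout) {
  auto dotOp = mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout);
  return dotOp &&
         mlir::isa<triton::gpu::BlockedEncodingAttr>(dotOp.getParent());
}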

Contributor Author

will do

Contributor Author

done

@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
      // Test that methods of DistributedEncoding return the same values
      Type eltTy = Float32Type::get(&ctx);

-     ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
+     if (!is_dot_op_with_block_parent(distributedEncoding)) {
+       ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder());
Contributor Author

Wrong order

      ASSERT_EQ(distributedEncoding.getContigPerThread(),
                linearEncoding.getContigPerThread());
      if (!is_dot_op_with_block_parent(distributedEncoding)) {
        ASSERT_EQ(distributedEncoding.getRepOrder(),
Contributor Author

wrong order

      if (!is_dot_op_with_block_parent(distributedEncoding)) {
        ASSERT_EQ(distributedEncoding.getRepOrder(),
                  linearEncoding.getRepOrder());
        ASSERT_EQ(distributedEncoding.getContigPerThread(),
Contributor Author

llvm::SmallVector<unsigned int> mlir::triton::gpu::DotOperandEncodingAttr::getContigPerThread(): Assertion `kWidth != 0 && "Do not support kWidth=0"' failed.

      ASSERT_EQ(distributedEncoding.getThreadOrder(),
                linearEncoding.getThreadOrder());
      if (!is_dot_op_with_block_parent(distributedEncoding)) {
        ASSERT_EQ(distributedEncoding.getThreadOrder(),
Contributor Author

Wrong order: { 1, 0 } vs { 0, 1 }

      ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
                linearEncoding.getThreadsPerWarp());
      if (!is_dot_op_with_block_parent(distributedEncoding)) {
        ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
Contributor Author

LLVM ERROR: getThreadsPerWarp not implemented for DotOperandEncodingAttr

@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
      // If we are not using CGAs, the order is meaningless
      auto useCGA =
          baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1);
-     if (useCGA) {
+     if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
Contributor Author

Wrong order: { 1, 0 } vs { 0, 1 }

@lezcano
Contributor

lezcano commented Feb 11, 2025

Could all the order issues be because the LinearEncodingAttr is computing the correct order for opIdx=1 while the legacy path is not?

Contributor

@lezcano lezcano left a comment

Anyway, that's a preexisting issue. Better to have skips than not testing at all! Thank you

@lezcano lezcano enabled auto-merge (squash) February 11, 2025 16:10
@lezcano lezcano merged commit 6afc767 into triton-lang:main Feb 11, 2025
7 checks passed
@anmyachev anmyachev deleted the get-size-per-thread branch February 11, 2025 18:40
@anmyachev
Contributor Author

@lezcano thank you for the review and suggestions!
