-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement DotOperandEncodingAttr::getSizePerThread
with block layout parent
#5863
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -516,6 +516,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { | |
distributedEncodings.push_back(blockedEncoding); | ||
distributedEncodings.push_back( | ||
triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); | ||
// Create an opIdx=0 and opIdx=1 encoding | ||
for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { | ||
distributedEncodings.push_back( | ||
triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, | ||
blockedEncoding, 0)); | ||
} | ||
} | ||
} | ||
} | ||
|
@@ -538,6 +544,12 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { | |
} | ||
} | ||
|
||
auto is_dot_op_with_block_parent = [](Attribute layout) { | ||
auto dot_layout = dyn_cast<triton::gpu::DotOperandEncodingAttr>(layout); | ||
return dot_layout && | ||
isa<triton::gpu::BlockedEncodingAttr>(dot_layout.getParent()); | ||
}; | ||
|
||
for (const auto &distributedEncoding : distributedEncodings) { | ||
for (auto shape : shapes) { | ||
if (auto sliceEncoding = | ||
|
@@ -558,29 +570,37 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { | |
// Test that methods of DistributedEncoding return the same values | ||
Type eltTy = Float32Type::get(&ctx); | ||
|
||
ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); | ||
if (!is_dot_op_with_block_parent(distributedEncoding)) { | ||
ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); | ||
} | ||
ASSERT_EQ(distributedEncoding.getTotalElemsPerThread(shape), | ||
linearEncoding.getTotalElemsPerThread(shape)); | ||
ASSERT_EQ(distributedEncoding.getElemsPerThread(shape), | ||
linearEncoding.getElemsPerThread(shape)); | ||
ASSERT_EQ(distributedEncoding.getRepOrder(), | ||
linearEncoding.getRepOrder()); | ||
ASSERT_EQ(distributedEncoding.getContigPerThread(), | ||
linearEncoding.getContigPerThread()); | ||
if (!is_dot_op_with_block_parent(distributedEncoding)) { | ||
ASSERT_EQ(distributedEncoding.getRepOrder(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wrong order |
||
linearEncoding.getRepOrder()); | ||
ASSERT_EQ(distributedEncoding.getContigPerThread(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
linearEncoding.getContigPerThread()); | ||
} | ||
// DotOperandEncodingAttr::getWarpOrder() is not defined | ||
if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) { | ||
ASSERT_EQ(distributedEncoding.getWarpOrder(), | ||
linearEncoding.getWarpOrder()); | ||
} | ||
ASSERT_EQ(distributedEncoding.getThreadOrder(), | ||
linearEncoding.getThreadOrder()); | ||
if (!is_dot_op_with_block_parent(distributedEncoding)) { | ||
ASSERT_EQ(distributedEncoding.getThreadOrder(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrong order: |
||
linearEncoding.getThreadOrder()); | ||
} | ||
// For slice these do not equal the total number of lines / warps | ||
// See [Note. Divergence of methods wrt. legacy layouts] | ||
if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) { | ||
ASSERT_EQ(distributedEncoding.getWarpsPerCTA(), | ||
linearEncoding.getWarpsPerCTA()); | ||
ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), | ||
linearEncoding.getThreadsPerWarp()); | ||
if (!is_dot_op_with_block_parent(distributedEncoding)) { | ||
ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
linearEncoding.getThreadsPerWarp()); | ||
} | ||
} | ||
// Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes | ||
// the second repetition along K as the second tile. | ||
|
@@ -602,7 +622,7 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { | |
// If we are not using CGAs, the order is meaningless | ||
auto useCGA = | ||
baseEncoding.getCTAsPerCGA() != SmallVector<unsigned>(rank, 1); | ||
if (useCGA) { | ||
if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrong order: |
||
ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder()); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrong order