Accept axis mapping when defining MmaOp #3391
Conversation
@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
This was only needed for defining MmaOp::mAxes() and friends, but:
- those methods are never used, so I removed them, and
- we can reconstruct that information easily using mma->axisMapping().
!test
!test
This caused failures in split-K
!test
Using the wrong graph meant that we could not detect any id roles
!test
After this PR, one thing we can do is specify the dimension order of the output (see Line 603 in d34553f).
So in that case we can see how the AxisMapping is standing in for a root->logical reordering. Since there is one for each input operand, this feels like another nice use case for read/write/compute domains, as suggested by @zasdfgbnm for indexing ldmatrix.
csrc/ir/internal_nodes.h
// corresponding position of each output axis in either the A or B input.
// Positions are absolute and refer to the noReductions logical domain. NOTE:
// -1 indicates that the axis does not exist, so Broadcast and Reduction
// dimensions should not have position -1.
Could you provide an example? For example, what does it mean if we have:
a_axes = {0, 2, 1, 3}
b_axes = {1, 0, 3, 2}
Does it mean a.axis(0) is mapped to b.axis(1)?
I'll put this in a comment but yes, this would mean we would map the following sets:
{a.axis(0), b.axis(1), out.axis(0)}
{a.axis(2), b.axis(0), out.axis(1)}
{a.axis(1), b.axis(3), out.axis(2)}
{a.axis(3), b.axis(2), out.axis(3)}
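To make the convention concrete, here is a minimal standalone sketch (plain C++, not nvFuser code; the vector names simply mirror a_axes/b_axes above) that reconstructs those mapped groups. The assumed convention is that entry i of each vector is output axis i, and the value is the matching operand axis, with -1 meaning that operand has no corresponding axis.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Mapping from the example above: entry i gives, for output axis i,
  // the matching axis in A or B (-1 would mean "no such axis").
  std::vector<int> a_axes = {0, 2, 1, 3};
  std::vector<int> b_axes = {1, 0, 3, 2};

  for (int out = 0; out < static_cast<int>(a_axes.size()); ++out) {
    std::printf("out.axis(%d) <->", out);
    if (a_axes[out] != -1) {
      std::printf(" a.axis(%d)", a_axes[out]);
    }
    if (b_axes[out] != -1) {
      std::printf(" b.axis(%d)", b_axes[out]);
    }
    std::printf("\n");
  }
  return 0;
}
```

Running this prints the same four groups listed above.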
const mma_utils::MatmulPattern& pattern = patterns.front();

IdModel id_model(&fusion);
const ValGraph& permissive_graph =
Let's use broadcast graph instead of permissive graph.
I am actually planning to do this, but it means I need to update mma_utils since permissive is still used there, so I will do that in a follow-up.
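For reference, a rough sketch of what that follow-up would look like at this call site; it assumes IdModel exposes the graphs through idGraph(IdMappingMode::...), and mma_utils would need the same switch before this can land:

```cpp
IdModel id_model(&fusion);
// Follow-up sketch: take the broadcast graph rather than the permissive one.
// Assumes idGraph(IdMappingMode::BROADCAST) is available on this IdModel.
const ValGraph& broadcast_graph = id_model.idGraph(IdMappingMode::BROADCAST);
```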
!build
!build
This keeps the default interface of fusedMultiplySum but also adds an option to provide an MmaOp::AxisMapping object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.

This PR does not alter the behavior of mma_utils::MatmulPattern::translateToMmaOp, meaning we still have BroadcastOp in translations for Hopper matmuls, but that change should be relatively simple.

Fixes #3372
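Roughly, the intended usage looks like the sketch below. The exact overload signature, the AxisMapping field names (a_axes/b_axes), and the makeContigTensor helper are assumptions for illustration, not necessarily what lands in this PR:

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* a = makeContigTensor(2); // logical domain [M, K]
TensorView* b = makeContigTensor(2); // logical domain [K, N]
fusion.addInput(a);
fusion.addInput(b);

// Output logical domain is [M, N, rK]. For each output axis, record its
// position in each operand, or -1 if the operand has no such axis.
MmaOp::AxisMapping mapping;
mapping.a_axes = {0, -1, 1}; // M = a.axis(0), N absent, K = a.axis(1)
mapping.b_axes = {-1, 1, 0}; // M absent, N = b.axis(1), K = b.axis(0)

// Same default interface as before, plus the optional mapping; no BroadcastOp
// is inserted on the operands.
TensorView* out =
    fusedMultiplySum(a, b, /*axes=*/{2}, /*init=*/nullptr, mapping);
fusion.addOutput(out);
```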
The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In followup PRs I plan to do the following:
- Update moveInnerBroadcastLeft in the Hopper matmul scheduler, and we may also need to adjust the swizzling of the TMA smem load TensorView. At this point we will be able to schedule a manually defined MmaOp without broadcasted inputs using our automatic scheduler.
- Introduce MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true) and enable that in the Hopper matmul scheduler. At this point it will be safe for us to accept MatmulOp and LinearOp in the Hopper matmul scheduler.