Accept axis mapping when defining MmaOp #3391

Merged: 17 commits merged into main from mmaop_no_broadcast on Nov 12, 2024
Conversation

@jacobhinkle (Collaborator) commented on Nov 9, 2024:

This keeps the default interface of fusedMultiplySum but also adds an option to provide an MmaOp::AxisMapping object. This mapping defines, for each output dimension, which axis in each operand (if any) corresponds to that output dimension.
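
For illustration, a minimal sketch of what such a definition might look like. The AxisMapping field names (a_axes/b_axes) come from the review discussion below; the exact fusedMultiplySum overload, helper names, and axis values here are assumptions, not the verbatim API added by this PR:

// Sketch only: the overload with a mapping argument is assumed from the
// description above. Operands are defined without broadcast axes:
//   tv_a: [M, K], tv_b: [N, K], output: [M, N, rK]
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv_a = makeContigTensor(2); // [M, K]
TensorView* tv_b = makeContigTensor(2); // [N, K]
fusion.addInput(tv_a);
fusion.addInput(tv_b);

MmaOp::AxisMapping mapping;
mapping.a_axes = {0, -1, 1}; // out M <- A pos 0, N absent in A, out K <- A pos 1
mapping.b_axes = {-1, 0, 1}; // M absent in B, out N <- B pos 0, out K <- B pos 1

// Default fusedMultiplySum interface plus an optional mapping (assumed form).
TensorView* tv_out =
    fusedMultiplySum(tv_a, tv_b, /*axes=*/{-1}, /*init=*/nullptr, mapping);
fusion.addOutput(tv_out);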

This PR does not alter the behavior of mma_utils::MatmulPattern::translateToMmaOp, meaning we still have BroadcastOp in translations for Hopper matmuls, but that change should be relatively simple.

Fixes #3372

The included test only checks that dimensions are properly mapped in an MmaOp defined without broadcast axes. In follow-up PRs I plan to do the following:

  1. Demonstrate scheduling a Hopper matmul with unbroadcasted inputs manually. This should surface any bugs in the lowering of the MmaOp instruction when broadcasts are absent.
  2. Ensure that we don't depend on having broadcast dims in the Hopper matmul scheduler. For example, we will handle this case in moveInnerBroadcastLeft, and we may also need to adjust the swizzling of the TMA smem load TensorView. At that point we will be able to use our automatic scheduler on a manually defined MmaOp without broadcasted inputs.
  3. Add an option MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true) and enable that in the Hopper matmul scheduler. At this point it will be safe for us to accept MatmulOp and LinearOp in the Hopper matmul scheduler.

@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
@jacobhinkle (Author) commented:

This was only needed for defining MmaOp::mAxes() and friends, but:

  1. Those methods are never used, so I removed them, and
  2. We can reconstruct that information easily using mma->axisMapping().

@jacobhinkle changed the title from "[WIP] Accept axis mapping when defining MmaOp" to "Accept axis mapping when defining MmaOp" on Nov 10, 2024
@jacobhinkle commented:

!test

@jacobhinkle commented:

!test

Commit: This caused failures in split-K
@jacobhinkle commented:

!test

Commit: Using the wrong graph meant that we could not detect any id roles
@jacobhinkle commented:

!test

@jacobhinkle marked this pull request as ready for review on November 11, 2024 at 01:48
@jacobhinkle commented:

After this PR, one thing we can do is specify the dimension order of the output of the MmaOp independently from the inputs. When we translate MatmulOp and LinearOp, the output already has logical order M, N, and we are free to place K wherever we want, so I'll place it last. I think this will let us avoid using commitLeafToLogical as is done here:

tv2->commitLeafToLogical();

So in that case we can see how the AxisMapping stands in for a root->logical reordering. Since there is one mapping for each input operand, this feels like another nice use case for read/write/compute domains, as suggested by @zasdfgbnm for indexing ldmatrix.

csrc/ir/nodes.cpp (outdated review thread, resolved)
// corresponding position of each output axis in either the A or B input.
// Positions are absolute and refer to the noReductions logical domain. NOTE:
// -1 indicates that the axis does not exist, so Broadcast and Reduction
// dimensions should not have position -1.
Reviewer (Collaborator) commented:

Could you provide an example? For example, what does it mean if we have:

a_axes = {0, 2, 1, 3}
b_axes = {1, 0, 3, 2}

Does it mean a.axis(0) is mapped to b.axis(1)?

@jacobhinkle (Author) replied:

I'll put this in a comment, but yes, this would mean we would map the following sets:

{a.axis(0), b.axis(1), out.axis(0)}
{a.axis(2), b.axis(0), out.axis(1)}
{a.axis(1), b.axis(3), out.axis(2)}
{a.axis(3), b.axis(2), out.axis(3)}
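
In code form this example would be expressed roughly as follows (a sketch; only the a_axes/b_axes values are taken from the question above, everything else is illustrative):

// For each output position i, a_axes[i] and b_axes[i] hold the matching
// position in A's and B's noReductions logical domain (-1 = no such axis).
MmaOp::AxisMapping mapping;
mapping.a_axes = {0, 2, 1, 3};
mapping.b_axes = {1, 0, 3, 2};
// Reading column-wise gives one mapped set per output axis:
//   i = 0: {out.axis(0), a.axis(0), b.axis(1)}
//   i = 1: {out.axis(1), a.axis(2), b.axis(0)}
//   i = 2: {out.axis(2), a.axis(1), b.axis(3)}
//   i = 3: {out.axis(3), a.axis(3), b.axis(2)}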

csrc/ir/internal_nodes.h (outdated review thread, resolved)
const mma_utils::MatmulPattern& pattern = patterns.front();

IdModel id_model(&fusion);
const ValGraph& permissive_graph =
Reviewer (Collaborator) commented:

Let's use broadcast graph instead of permissive graph.

@jacobhinkle (Author) replied:

I am actually planning to do this, but it means I need to update mma_utils since permissive is still used there, so I will do that in a follow-up.

csrc/ir/internal_nodes.h (outdated review thread, resolved)
csrc/ops/arith.cpp (outdated review thread, resolved)
@jacobhinkle commented:

!build

@jacobhinkle commented:

!build

@jacobhinkle merged commit 030c2ba into main on Nov 12, 2024
13 of 14 checks passed
@jacobhinkle deleted the mmaop_no_broadcast branch on November 12, 2024 at 20:43
Successfully merging this pull request may close: Enable MmaOp to receive unbroadcasted inputs (#3372)