F8I4 Grouped Gemm Optimization for Sparse M #3854
Open

jwfromm wants to merge 1 commit into pytorch:main from jwfromm:export-D71510967
+108
−46
Conversation
This pull request was exported from Phabricator. Differential Revision: D71510967
Summary: X-link: facebookresearch/FBGEMM#945 In cases where there are many groups but few have a non-zero number of routed tokens, we pay a high overhead. For example, if a single token is routed to one of 128 experts, the compute is the same as one token being routed to a single expert, yet the runtime is much higher, presumably due to kernel inefficiencies in looping over the empty groups. This diff changes how kernel arguments are set up so that the grouped GEMM runs over min(total_M, groups). This lets us skip the many groups where no compute is required and considerably improves performance in those cases. As an example of the effect of this diff, when total_M is 1 and there are 128 groups, latency is 3X lower thanks to this change. Reviewed By: jiawenliu64 Differential Revision: D71510967
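The idea above can be sketched on the host side: given per-group routed-token counts, only groups with a non-zero M produce a GEMM problem, so the kernel argument list has at most min(total_M, groups) entries. This is a minimal illustrative sketch, not FBGEMM's actual argument-setup code; the function and field names (`build_grouped_gemm_args`, `a_offset`) are hypothetical.

```python
def build_grouped_gemm_args(m_sizes, n, k):
    """Build per-problem arguments for a grouped GEMM, skipping empty groups.

    m_sizes[g] is the number of tokens routed to group (expert) g; every
    problem shares the same N and K. Hypothetical sketch of the technique
    described in the summary, not FBGEMM's real API.
    """
    args = []
    row_offset = 0  # row into the concatenated activation matrix A
    for g, m in enumerate(m_sizes):
        if m > 0:
            # Only groups with routed tokens become a GEMM problem.
            args.append({"group": g, "m": m, "n": n, "k": k,
                         "a_offset": row_offset})
        row_offset += m
    # With total_M tokens spread over len(m_sizes) groups, at most
    # min(total_M, groups) groups can be non-empty, so the grouped kernel
    # iterates over far fewer problems when routing is sparse.
    assert len(args) <= min(sum(m_sizes), len(m_sizes))
    return args
```

In the extreme case from the summary (total_M = 1, 128 groups), this produces a single problem instead of 128, which is where the latency win comes from.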
jwfromm force-pushed from 01ca3be to 7b99508
jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request on Mar 21, 2025