F8I4 Grouped Gemm Optimization for Sparse M (pytorch#3854)
Summary:
X-link: facebookresearch/FBGEMM#945
In cases where there are many groups but only a few receive a non-zero number of routed tokens, it turns out we pay a high overhead. For example, if a single token is routed to one of 128 experts, the compute is the same as one token being routed to a single expert, yet the runtime is much higher.
Presumably there are kernel inefficiencies involved in looping over the empty groups. This diff changes how the kernel arguments are set up so that the grouped gemm runs over at most min(total_M, groups) groups. This lets us skip the many groups where no compute is required and considerably improves performance in those cases; see the sketch below.
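As a rough illustration of the argument-setup idea (a minimal sketch, not FBGEMM's actual kernel-argument code; `build_group_args` and the `m_sizes` layout are hypothetical), the snippet below compacts the per-group token counts down to the non-empty groups, whose count is bounded by min(total_M, groups):

```python
import torch

def build_group_args(m_sizes: torch.Tensor, num_groups: int):
    """Hypothetical helper: m_sizes[g] = number of tokens routed to expert g.

    Returns the indices of non-empty groups and their row offsets into the
    stacked activation tensor, so a grouped gemm can loop over at most
    min(total_M, num_groups) groups instead of all of them.
    """
    total_m = int(m_sizes.sum())
    # At most total_m groups can be non-empty, since each non-empty group
    # holds at least one token.
    max_active = min(total_m, num_groups)
    active = torch.nonzero(m_sizes, as_tuple=False).flatten()
    assert active.numel() <= max_active
    # Exclusive prefix sum gives each group's starting row offset.
    offsets = torch.cumsum(m_sizes, dim=0) - m_sizes
    return active, offsets[active]

# Example: a single token routed to expert 37 out of 128 experts.
m_sizes = torch.zeros(128, dtype=torch.int64)
m_sizes[37] = 1
active, offsets = build_group_args(m_sizes, num_groups=128)
print(active.tolist(), offsets.tolist())  # [37] [0]
```

Under these assumptions the kernel launch only ever sees one group in this example, rather than iterating over 127 empty ones.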
As an example of the effect of this diff, when total_M is 1 and there are 128 groups, latency drops by roughly 3X with this change.
Reviewed By: jiawenliu64
Differential Revision: D71510967