warp reduction in x and y dims #2966

Merged: 8 commits merged from llu/warp_redu_xy into main on Oct 7, 2024

Conversation

@liqiangxl (Collaborator) commented Sep 19, 2024

Issue
The first part of the innerOuter persistent kernel is an inner reduction; its inner dim is parallelized by vectorization, bdimx, bdimy, and the persistent batch. Both bdimx and bdimy are used because we want to re-use bdimy to parallelize the outer dim and re-use bdimx to parallelize the inner dim in the second part of the kernel. However, this keeps us from using warp reduction, since the current runtime function only supports reduction in bdimx.
Fix
(1) Add another warp reduction runtime function that supports reduction in both the x and y dims (see the sketch after this list).
(2) bdimx and bdimy are explicitly set to static and passed to the warp reduction as template parameters.
(3) Since we are launching a cooperative kernel, gdimy is also static.
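
For illustration, here is a minimal sketch of a warp reduction that folds threadIdx.x and threadIdx.y into a single reduction domain, with the static block dimensions passed as template parameters. This is only a simplified sketch under stated assumptions, not the actual runtime function added by this PR: the real warpReduceTIDXY also takes a bool Aligned template parameter and uses nvFuser's block synchronization helpers, which are omitted here, and the helper name, the reduction_op signature void(T&, T), and the single-value interface are assumptions made for the example.

template <int BDIMX, int BDIMY, typename T, typename Func>
__device__ void warpReduceTIDXY_sketch(
    T& out,
    const T& inp,
    Func reduction_op,
    T* shared_mem) {
  constexpr int kWarpSize = 32;
  constexpr int kNumThreads = BDIMX * BDIMY;
  constexpr int kNumWarps = kNumThreads / kWarpSize;
  static_assert(
      kNumThreads % kWarpSize == 0, "BDIMX * BDIMY must be a multiple of 32");

  // Fold the 2D thread index into one linear index so that x and y form a
  // single reduction domain.
  const int tid = threadIdx.y * BDIMX + threadIdx.x;
  const int lane = tid % kWarpSize;
  const int warp_id = tid / kWarpSize;

  // Intra-warp tree reduction via register shuffles. Assumes T is a plain
  // arithmetic type supported by __shfl_down_sync and reduction_op has the
  // form void(T&, T).
  T val = inp;
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    T other = __shfl_down_sync(0xffffffff, val, offset);
    reduction_op(val, other);
  }

  // Combine the per-warp partial results through shared memory. Letting a
  // single thread accumulate them avoids needing an identity element.
  if (lane == 0) {
    shared_mem[warp_id] = val;
  }
  __syncthreads();
  if (tid == 0) {
    T result = shared_mem[0];
    for (int w = 1; w < kNumWarps; ++w) {
      reduction_op(result, shared_mem[w]);
    }
    out = result;
  }
  __syncthreads();
}

With the configuration from the Influence section below (bdimx=64, bdimy=4), this would be instantiated roughly as warpReduceTIDXY_sketch<64, 4>(out, in, add_op, smem) over 256 threads forming 8 warps, where add_op is something like a functor whose operator() accumulates b into a.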
Influence
(1) For a case with vectorization=8, bdimx=64, bdimy=4, and persistent=13, the inner reduction dim is scheduled as:

rS195{13}, rUS197{1}, rthreadIdx.z198{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i1, 8) ), 64) ), 4) ), 13) ), 1) )}, rthreadIdx.y194{4}, rthreadIdx.x192{64}, rV190{8}

The heuristic ensures bdimz = inner dim size / vectorization / bdimx / bdimy / persistent == 1, which gives us a static warp reduction across bdimx and bdimy.
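
As a concrete (hypothetical) example of that arithmetic: if the inner dim were i1 = 26624, the bdimz extent above evaluates to ceilDiv(ceilDiv(ceilDiv(ceilDiv(ceilDiv(26624, 8), 64), 4), 13), 1) = ceilDiv(ceilDiv(ceilDiv(ceilDiv(3328, 64), 4), 13), 1) = ceilDiv(ceilDiv(ceilDiv(52, 4), 13), 1) = ceilDiv(ceilDiv(13, 13), 1) = 1, so the entire inner dim is covered by vectorization, bdimx, bdimy, and the persistent batch.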
Performance:
A100 layer norm backward: [benchmark plot]
A100 RMS norm backward: [benchmark plot]
H100 layer norm backward (local run, not included in CI): [benchmark plot]
Other hardware (local run, not included in CI): link
Other options
(1) We could set bdimy to dynamic; then there is no need to use bdimz and we save one split. However, to use warp reduction we would need to pad bdimy and ensure bdimx * padded_bdimy % 32 == 0. We would also need to revise the outer reduction part of the kernel, which assumes bdimy is static.

@liqiangxl (Collaborator Author)

!build --diff --pybench

@jjsjann123 jjsjann123 self-requested a review September 19, 2024 19:04
@liqiangxl (Collaborator Author)

!build

@liqiangxl liqiangxl marked this pull request as ready for review October 2, 2024 02:04
  std::pair<IterDomain*, IterDomain*> reduction_dims =
      std::make_pair(reduction_on_xdim, nullptr);
  if (reduction_on_xdim->hasPaddingToMultipleOfWarp()) {
    return std::optional<std::pair<IterDomain*, IterDomain*>>(reduction_dims);
Collaborator (jjsjann123):

nitpick: return std::make_pair(reduction_on_xdim, nullptr)

Collaborator Author (liqiangxl):

revised to

      return std::optional<std::pair<IterDomain*, IterDomain*>>(
          std::make_pair(reduction_on_xdim, nullptr));

I think we still want to keep the std::optional since the function is getMaybeWarpReductionDim; otherwise, we would need to remove Maybe from the function name and make other changes in its callers.

Collaborator (jjsjann123):

No, I wasn't suggesting changing the signature. It should automatically convert that to a std::optional. Is there any benefit to having that explicit?

Collaborator Author (liqiangxl):

Got you. I didn't realize it auto-converts to std::optional. Changed.
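
For reference, a minimal standalone example of the implicit conversion being discussed (not nvFuser code; maybe_dims and the values are made up): returning a std::pair from a function declared to return std::optional<std::pair<...>> constructs the optional automatically, so the explicit wrapper is unnecessary.

#include <iostream>
#include <optional>
#include <utility>

// Hypothetical stand-in for the return pattern in getMaybeWarpReductionDim.
std::optional<std::pair<int, int>> maybe_dims(bool found) {
  if (found) {
    // Implicitly converted to std::optional<std::pair<int, int>>.
    return std::make_pair(64, 4);
  }
  return std::nullopt;
}

int main() {
  if (auto dims = maybe_dims(true)) {
    std::cout << dims->first << " x " << dims->second << "\n"; // prints 64 x 4
  }
  return 0;
}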

  if (reduction_on_xdim->extent()->isConstInt()) {
    auto extent_value = reduction_on_xdim->extent()->evaluate();
    if (extent_value % at::cuda::warp_size() == 0) {
      return std::optional<std::pair<IterDomain*, IterDomain*>>(
Collaborator (jjsjann123):

ditto

Collaborator Author (liqiangxl):

same

  if ((extent_x_value * extent_y_value) % at::cuda::warp_size() == 0) {
    std::pair<IterDomain*, IterDomain*> reduction_dims =
        std::make_pair(reduction_on_xdim, reduction_on_ydim);
    return std::optional<std::pair<IterDomain*, IterDomain*>>(
Collaborator (jjsjann123):

ditto

Collaborator Author (liqiangxl):

same

@@ -611,6 +612,13 @@ std::unique_ptr<ReductionParams> innerOuterPersistentHeuristic(
    rparams->block_dim_iter_dom = ParallelType::TIDy;
  } else {
    rparams->block_dim_inner_reduction_extra = ParallelType::TIDy;
    rparams->static_bdimx = true;
    rparams->static_bdimy = true;
Collaborator (jjsjann123):

I think here we need to have all block dimensions be static, since bdimz == 1?

Collaborator Author (liqiangxl):

It is used to parallelize rthreadIdx.z198{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i1, 8) ), 64) ), 4) ), 13) ), 1) )}, which is a dynamic dim that depends on i1. Although the heuristic ensures it equals 1, we still can't set it to static unless we use the value of i1, and then the whole kernel becomes a static kernel and we lose dynamic shape support.

Collaborator (jjsjann123):

Ha, I see it now. Thanks for bearing with me here.

  rparams->lparams = LaunchParams(
      LaunchParams::UNINITIALIZED_VAL,
      iop.gdimy,
      LaunchParams::UNINITIALIZED_VAL,
      iop.bdimx,
      iop.bdimy,
-     LaunchParams::UNINITIALIZED_VAL);
+     iop.bdimz);
Collaborator (jjsjann123):

Naive question regarding iop.bdimz: why do we add this in the heuristic if it's always going to be 1?

Collaborator Author (liqiangxl):

It is redundant since we already have NVF_ERROR(iop.bdimz == 1, "bdimz must be 1.");
Removed this change.

@@ -88,4 +88,58 @@ __device__ void warpReduceTIDX(
}
}

template <int BDIMX, int BDIMY, bool Aligned, typename T, typename Func>
Collaborator (jjsjann123):

Why do we need this template function with a slightly different implementation for the inter-warp reduction?

Collaborator Author (liqiangxl):

The existing version warpReduceTIDX does the reduction in the X dim; the Y dim is used for iteration. It requires bdimx % 32 == 0.
This new version warpReduceTIDXY does the reduction in both the X and Y dims and requires (bdimx * bdimy) % 32 == 0. So they are different. Also, bdimx and bdimy are static values, so we can pass them as template parameters.
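
For contrast, below is a similarly simplified sketch of the warpReduceTIDX pattern described above (again illustrative only; the actual runtime function's signature, Aligned parameter, and synchronization helpers differ): each threadIdx.y row performs an independent reduction across threadIdx.x, so only bdimx % 32 == 0 is required.

template <int BDIMX, typename T, typename Func>
__device__ void warpReduceTIDX_sketch(
    T& out,
    const T& inp,
    Func reduction_op,
    T* shared_mem) {
  constexpr int kWarpSize = 32;
  static_assert(BDIMX % kWarpSize == 0, "BDIMX must be a multiple of 32");
  constexpr int kWarpsPerRow = BDIMX / kWarpSize;

  // Lane and warp indices are formed from threadIdx.x only; threadIdx.y
  // selects which independent reduction (iteration row) this thread is in.
  const int lane = threadIdx.x % kWarpSize;
  const int warp_in_row = threadIdx.x / kWarpSize;

  // Intra-warp reduction along x. Each warp lies entirely within one y-row
  // because BDIMX is a multiple of 32.
  T val = inp;
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    T other = __shfl_down_sync(0xffffffff, val, offset);
    reduction_op(val, other);
  }

  // Combine the per-warp partials of this row; each row uses its own slice
  // of shared memory (shared_mem must hold blockDim.y * kWarpsPerRow values).
  if (lane == 0) {
    shared_mem[threadIdx.y * kWarpsPerRow + warp_in_row] = val;
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    T result = shared_mem[threadIdx.y * kWarpsPerRow];
    for (int w = 1; w < kWarpsPerRow; ++w) {
      reduction_op(result, shared_mem[threadIdx.y * kWarpsPerRow + w]);
    }
    out = result;
  }
  __syncthreads();
}

In warpReduceTIDXY the two dims are instead folded into one linear index (tid = threadIdx.y * BDIMX + threadIdx.x), which is why the requirement relaxes to (bdimx * bdimy) % 32 == 0, as in the sketch after the Fix section above.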

Collaborator (jjsjann123):

Ah, my bad. I didn't realize that in warpReduceTIDX we are still using threadIdx.y. Earlier I was suggesting that these two should be merged, but that doesn't look worth it.

@jjsjann123 (Collaborator)

The code change looks good to me.
One last question regarding the performance: with A100 layer norm backward, what's the worst regression we are looking at? The axes don't show that.

@liqiangxl (Collaborator Author)

> The code change looks good to me. One last question regarding the performance: with A100 layer norm backward, what's the worst regression we are looking at? The axes don't show that.

Double checked, it is 0.86x (link). [benchmark plot]

@jjsjann123 (Collaborator) left a review:

LGTM

@liqiangxl (Collaborator Author)

!build

@liqiangxl liqiangxl merged commit 615177d into main Oct 7, 2024
19 of 20 checks passed
@liqiangxl liqiangxl deleted the llu/warp_redu_xy branch October 7, 2024 13:29