warp reduction in x and y dims #2966
Conversation
!build --diff --pybench
!build
csrc/device_lower/utils.cpp
Outdated
std::pair<IterDomain*, IterDomain*> reduction_dims =
    std::make_pair(reduction_on_xdim, nullptr);
if (reduction_on_xdim->hasPaddingToMultipleOfWarp()) {
  return std::optional<std::pair<IterDomain*, IterDomain*>>(reduction_dims);
nitpick: return std::make_pair(reduction_on_xdim, nullptr)
revised to
return std::optional<std::pair<IterDomain*, IterDomain*>>(
std::make_pair(reduction_on_xdim, nullptr));
I think we still want to keep this std::optional, since the function is getMaybeWarpReductionDim; otherwise, we need to remove Maybe from the function name and make other changes in its callers.
No, I wasn't suggesting changing the signature. It should automatically convert that to an std::optional. Is there any benefit to having that explicit?
Got you, I didn't realize it can auto-convert to std::optional. Changed.
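For reference, a minimal sketch of the implicit conversion discussed above (IterDomain is stubbed out so the snippet stands alone, and the signature and body are simplified down to the one return in question):

#include <optional>
#include <utility>

struct IterDomain {};  // stub, just to make the sketch self-contained

// Returning a std::pair from a function declared to return
// std::optional<std::pair<...>> wraps the value implicitly, so the
// explicit std::optional<...>(...) construction can be dropped.
std::optional<std::pair<IterDomain*, IterDomain*>> getMaybeWarpReductionDim(
    IterDomain* reduction_on_xdim) {
  return std::make_pair(reduction_on_xdim, nullptr);  // implicit conversion
}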
csrc/device_lower/utils.cpp
Outdated
if (reduction_on_xdim->extent()->isConstInt()) {
  auto extent_value = reduction_on_xdim->extent()->evaluate();
  if (extent_value % at::cuda::warp_size() == 0) {
    return std::optional<std::pair<IterDomain*, IterDomain*>>(
ditto
same
csrc/device_lower/utils.cpp
Outdated
if ((extent_x_value * extent_y_value) % at::cuda::warp_size() == 0) {
  std::pair<IterDomain*, IterDomain*> reduction_dims =
      std::make_pair(reduction_on_xdim, reduction_on_ydim);
  return std::optional<std::pair<IterDomain*, IterDomain*>>(
ditto
same
@@ -611,6 +612,13 @@ std::unique_ptr<ReductionParams> innerOuterPersistentHeuristic(
    rparams->block_dim_iter_dom = ParallelType::TIDy;
  } else {
    rparams->block_dim_inner_reduction_extra = ParallelType::TIDy;
    rparams->static_bdimx = true;
    rparams->static_bdimy = true;
I think here we need to have all block dimensions be static, since bdimz == 1?
It is used to parallelize rthreadIdx.z198{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i1, 8) ), 64) ), 4) ), 13) ), 1) )}, which is a dynamic dim that depends on i1. Although the heuristic ensures it equals 1, we still can't set it to static unless we use the value of i1; then the whole kernel becomes a static kernel and we lose dynamic-shape support.
Ha, I see it now. Thanks for bearing with me here.
rparams->lparams = LaunchParams(
    LaunchParams::UNINITIALIZED_VAL,
    iop.gdimy,
    LaunchParams::UNINITIALIZED_VAL,
    iop.bdimx,
    iop.bdimy,
-   LaunchParams::UNINITIALIZED_VAL);
+   iop.bdimz);
Naive question regarding iop.bdimz: why do we add this in the heuristic if it's always going to be 1?
It is redundant since we already have NVF_ERROR(iop.bdimz == 1, "bdimz must be 1.");
Removed this change.
@@ -88,4 +88,58 @@ __device__ void warpReduceTIDX(
  }
}

template <int BDIMX, int BDIMY, bool Aligned, typename T, typename Func>
Why do we need this template function with a slightly different implementation for inter-warp reduction?
The existing version warpReduceTIDX does the reduction in the X dim; the Y dim is used for iteration. It requires bdimx % 32 == 0. The new version warpReduceTIDXY does the reduction in the X & Y dims; it requires (bdimx * bdimy) % 32 == 0. So they are different. Also, bdimx and bdimy are static values, so we can use template vars.
Ah, my bad. I didn't realize that in warpReduceTIDX we are still using threadIdx.y. Earlier I was suggesting that these two should be merged together, but that doesn't look worth it.
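For intuition, here is a minimal standalone sketch of the core idea behind reducing over the combined x/y thread dims with static extents. This is only an illustration, not nvfuser's warpReduceTIDXY: the real function also takes an Aligned flag and a generic Func reduction op (see the template above), and its signature and shared-memory protocol differ. All names here are invented for the sketch.

// Sketch: sum-reduce `val` across all BDIMX * BDIMY threads of a block,
// assuming (BDIMX * BDIMY) % 32 == 0.
template <int BDIMX, int BDIMY, typename T>
__device__ T blockReduceXYSketch(T val, T* smem /* >= (BDIMX*BDIMY)/32 */) {
  constexpr int kNumThreads = BDIMX * BDIMY;
  constexpr int kNumWarps = kNumThreads / 32;
  static_assert(kNumThreads % 32 == 0, "x*y extent must be a multiple of 32");
  static_assert(kNumWarps <= 32, "per-warp partials must fit in one warp");

  // Linearize (threadIdx.x, threadIdx.y) so one warp can span both dims.
  const int tid = threadIdx.y * BDIMX + threadIdx.x;
  const int lane = tid % 32;
  const int warp = tid / 32;

  // Intra-warp tree reduction via register shuffles.
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_xor_sync(0xffffffff, val, offset);
  }

  // One partial result per warp goes through shared memory.
  if (lane == 0) {
    smem[warp] = val;
  }
  __syncthreads();

  // The first warp reduces the per-warp partials.
  if (warp == 0) {
    val = (lane < kNumWarps) ? smem[lane] : T(0);
    for (int offset = 16; offset > 0; offset /= 2) {
      val += __shfl_xor_sync(0xffffffff, val, offset);
    }
  }
  return val;  // fully reduced value is valid in lane 0 of warp 0
}

With the bdimx=64, bdimy=4 configuration discussed in this PR, this would be instantiated as blockReduceXYSketch<64, 4>(val, smem) over 256 threads (8 warps), with smem holding at least 8 elements of T.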
The code change looks good to me.
Double checked, it is 0.86x, link.
LGTM
!build
Issue
The first part of the innerOuter persistent kernel is an inner reduction; the inner dim is parallelized by vectorization, bdimx, bdimy, and persistent. Both bdimx and bdimy are used because we want to re-use bdimy to parallelize the outer dim and re-use bdimx to parallelize the inner dim in the second part of the kernel. However, this keeps us from using warp reduction, since the current runtime function only supports reduction in bdimx.
Fix
(1) Add another warp reduction runtime function to support reduction in x and y dims.
(2) bdimx and bdimy are explicitly set to static and passed to the warp reduction as template params.
(3) We are launching a cooperative kernel, so gdimy is also static.
Influence
(1) For a case with vectorization=8, bdimx=64, bdimy=4, and persistent=13, the inner reduction dim is scheduled as:
The heuristic ensures bdimz = inner dim size / vectorization / bdimx / bdimy / persistent == 1 (e.g., for an inner dim of 26624 = 8 * 64 * 4 * 13, bdimz = 26624 / 8 / 64 / 4 / 13 = 1); this gives us a static warp reduction across bdimx and bdimy.
Performance:
A100 layer norm backward:
A100 RMS norm backward:
H100 layer norm backward (local run, not included in CI):
Other hardware (local run, not included in CI): link
Other options
(1) We can set bdimy to dynamic; then there is no need to use bdimz, and we can save one split. To use warp reduction in that case, we would need to pad bdimy and ensure bdimx * padded_bdimy % 32 == 0 (a sketch of that padding arithmetic follows below). We would also need to revise the outer reduction part of the kernel, which assumes bdimy is static.
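A minimal sketch of the padding arithmetic this option would require (the helper name is hypothetical, not part of the codebase):

#include <cstdint>

// Hypothetical helper: smallest padded_bdimy >= bdimy such that
// bdimx * padded_bdimy is a multiple of the warp size (32).
constexpr int64_t padBdimyToWarp(int64_t bdimx, int64_t bdimy) {
  constexpr int64_t kWarpSize = 32;
  int64_t padded = bdimy;
  while ((bdimx * padded) % kWarpSize != 0) {
    ++padded;
  }
  return padded;
}

static_assert(padBdimyToWarp(64, 4) == 4);  // 64 * 4 = 256, already a multiple
static_assert(padBdimyToWarp(24, 5) == 8);  // 24 * 8 = 192 = 6 * 32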