use aligned array for iter grouped reduction inputs #2934

Merged
merged 21 commits into main from llu/aligned_reg_array on Oct 3, 2024

Conversation

@liqiangxl (Collaborator) commented Sep 11, 2024

Fix #2930

Use an aligned array of registers when vectorized data transfer between registers and shared memory is needed.
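For context, a minimal sketch of what an "aligned array of registers" means here; the struct below is an illustrative stand-in modeled after the Array<type, size, align> declarations quoted later in this thread, not the exact nvFuser definition:

    // Illustrative only: a register array whose alignment is raised to
    // sizeof(element) * align factor, so a single vectorized (e.g. 16-byte)
    // transfer to/from shared memory is legal.
    template <typename T, int N, int Align = 1>
    struct alignas(sizeof(T) * Align) Array {
      T array[N];
    };

    __device__ void example() {
      float plain[4];          // only guaranteed alignof(float) = 4 bytes
      Array<float, 4, 4> reg;  // 16-byte aligned; safe target for float4 copies
      (void)plain;
      (void)reg;
    }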

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

Failed tests do not seem to be related to this PR: thunder.tests.test_grad.test_vjp_correctness_sdpa_manual_grad_forward_scaled_dot_product_attention_nvfuser_cuda_thunder.dtypes.bfloat16

@liqiangxl marked this pull request as ready for review September 16, 2024 17:45
@naoyam (Collaborator) left a comment

Only about codegen.cpp for now

csrc/codegen.cpp Outdated
ir_utils::isConsumedByIterGroupedReduction(tv)) {
vect_factor = kernel_->summary().num_grouped_iterations;
}
if (vect_factor > 0) {
Collaborator

When can this be 0?

Collaborator Author

It is initialized to 0, so it stays at 0 when the tv is not vectorized with gmem or smem.

csrc/codegen.cpp Outdated
} else if (
kernel_->summary().num_grouped_iterations > 1 &&
ir_utils::isConsumedByIterGroupedReduction(tv)) {
vect_factor = kernel_->summary().num_grouped_iterations;
Collaborator

In general, Kernel should already have the correct vectorization factor and CudaCodeGen should just be a straightforward printer. Is there any specific reason for this case?

Collaborator Author

Kernel only has the vectorization factor for global memory access, not for shared memory access. I added comments for clarity.

Should use an aligned array of registers when:
(1) vectorized ld/st with global memory: the tv exists in the kernel summary's vectorized_accesses.
(2) vectorized ld/st with shared memory: the tv is an input to an iteration-grouped reduction and is vectorized in the runtime function.

Another option is adding these tvs with vectorized shared memory access to kernel_->summary().vectorized_accesses.
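A condensed sketch of that decision, pieced together from the diff excerpts above (the vectorized_accesses lookup is paraphrased, not the verbatim codegen.cpp code):

    int64_t vect_factor = 0;
    const auto& summary = kernel_->summary();
    // (1) Vectorized ld/st with global memory: tv is recorded in the kernel
    //     summary's vectorized_accesses map.
    if (auto it = summary.vectorized_accesses.find(tv);
        it != summary.vectorized_accesses.end()) {
      vect_factor = it->second;
    } else if (
        summary.num_grouped_iterations > 1 &&
        ir_utils::isConsumedByIterGroupedReduction(tv)) {
      // (2) Vectorized ld/st with shared memory: tv is an input to an
      //     iteration-grouped reduction and the runtime function vectorizes
      //     the register <-> smem transfer.
      vect_factor = summary.num_grouped_iterations;
    }
    if (vect_factor > 0) {
      // declare tv as an aligned register array, e.g. Array<dtype, N, vect_factor>
    }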

Collaborator

It seems strange to me that the kernel summary only has info about global memory tensors but not shared memory. Is there any reason for that?

Collaborator Author

> It seems strange to me that the kernel summary only has info about global memory tensors but not shared memory. Is there any reason for that?

To be more precise, it should be: the kernel summary doesn't have info about vectorized access in runtime functions, e.g. blockIterGroupedYdimReduce().

Collaborator

I understand that. Does that matter?

Collaborator Author

Oh, I mean this statement is not true. The kernel summary has info about vectorized r/w for both global and shared memory, if those vectorized r/w were generated by the scheduler.
It doesn't have the info about vectorized r/w implemented in the runtime functions.

Collaborator

I know that. I'm asking why.

Collaborator Author

Because it doesn't check the runtime functions. These runtime functions are just strings when the kernel is generated.

Collaborator

It isn't necessary to do so, right? Your PR doesn't do that either. You check some condition in CudaCodeGen. All I'm asking is why it cannot be done when generating the kernel summary. Let me say this again:

> In general, Kernel should already have the correct vectorization factor and CudaCodeGen should just be a straightforward printer.

Collaborator Author

So you are suggesting doing option-2?

> Another option is adding these tvs with vectorized shared memory access to kernel_->summary().vectorized_accesses.

No problem, we can do that.

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

!build --diff

@liqiangxl (Collaborator Author) commented Sep 27, 2024

Revised to use option-2, where the tvs used in iter grouped reductions are added to kernel_->summary().vectorized_accesses.
Step-1: mark ParallelType::Group as a special vectorization.

      // ParallelType::Group is used for both reduction & normalization.
      // When used to group iteration dims of outer reduction tvs, it has
      // vectorized access to shared memory and global memory.
      if (ptype == ParallelType::Group) {
        auto def = tv->definition();
        auto grop = dynamic_cast<GroupedReductionOp*>(def);
        if (grop && (!grop->isAllreduce())) {
          has_grouped_vectorize_dim = true;
        }
      }

Step-2: call VectorizeValidator, where the tv and its producer are added to vectorizedSetInfo and vectorizedAccesses.

   if (has_vectorize_dim || has_misaligned_vectorize_dim ||
       has_grouped_vectorize_dim) {
     VectorizeValidator::validate(tv);
   }
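The net effect on the generated kernel is roughly the following (the tensor name and sizes are hypothetical, for illustration only):

    // Before: plain register array, alignment only guaranteed to 4 bytes.
    float T14[4];

    // After: the grouped-reduction input is emitted as an aligned array, so a
    // single 16-byte transfer between registers and shared memory is legal.
    Array<float, 4, 4> T14;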

@liqiangxl (Collaborator Author)

!build --diff

@liqiangxl (Collaborator Author)

!build --diff

@liqiangxl (Collaborator Author)

!build --diff

@liqiangxl (Collaborator Author)

@naoyam, I think it's ready for another round of review. The code diff is expected; the distributed failed tests are not related to this PR.

csrc/codegen.cpp Outdated
@@ -2972,6 +2960,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
aligned_array_of_regs_.insert(tv);
}
}
// tv is aligned if alias is aligned
@naoyam (Collaborator) Oct 1, 2024

I'm not sure what we are doing here, not just the change by this PR. When we have an Allocate node, does it automatically mean it's an Array allocation if it's an alias of another tensor but with a different type? It seems that's what is indicated by line 2956 since it does reinterpret_cast to an Array type. Is this really safe? Don't we need to check the type of the original allocation of the alias tensor?

Collaborator

Looks like that was originally introduced in #665. What do you think, @jacobhinkle?

Collaborator

> Is this really safe? Don't we need to check the type of the original allocation of the alias tensor?

Isn't that safe if the two types have the same sizeof? That is what #665 is doing.

Collaborator

Sorry, my question was incomplete. If the original tensor is defined as a non-aligned tensor, is it safe to reinterpret-cast it to an aligned array type? Isn't that what could happen?

Collaborator

That's right, if the original tensor had a different alignment from the new tensor. The danger here would be, for example, if the original tensor had vectorized access of width 2 and the second tensor that tries to alias it has the same sizeof(dtype) but a vectorized access of width 8. I didn't think about that in #665, but yes, I think we should either guard against that when we're setting up the alias or accommodate the alias's vectorized accesses when we codegen the original allocation.

Collaborator

Yeah, as far as I can see, the original tensor doesn't even seem to be guaranteed to be aligned at any size.

This doesn't need to be fixed in this PR, but please create an issue.

Collaborator

Actually, I forgot about this since I haven't looked at the aliasing code in a while, but we already do this analysis to guarantee the alias vectorizations are at most the same width as the original:

// Vectorized allocations require correct alignment so if [this_tv]
// is vectorized, the [reuse_tv] must be vectorized with the same
// or smaller factor.
// No need to check shared memory since it is always aligned to 16
// Bytes which is also the maximum vectorization width.
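A minimal sketch of the guard that comment describes (the getVectorizeWidth helper below is hypothetical, standing in for the real lookup in the allocation-reuse analysis):

    // Reuse [this_tv]'s allocation for [reuse_tv] only if the alias does not
    // need a larger vectorized access width than the original allocation.
    const int64_t this_width = getVectorizeWidth(this_tv);
    const int64_t reuse_width = getVectorizeWidth(reuse_tv);
    const bool can_alias = reuse_width <= this_width;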

Collaborator

Oh, I see. Glad it's a false alarm.

Collaborator

Actually, I wonder if it should specify the alignment size with reinterpret_cast.

<< " = *reinterpret_cast<Array<" << buffer_dtype << ", "
<< genInline(size) << ">*>(&" << genVariableName(alias_tv)

If I read this correctly, it just uses an Array type with no alignment requirement, for example, Array<float, 8>. The default alignment size is 1, so it seems this would tell the compiler that we are using a non-aligned address with vector loads and stores. The address is indeed aligned, so it should cause no problem, unless the compiler does something when a given address is not marked as properly aligned. I think it'd be safer to always use a properly aligned type, even when we know it's properly aligned.
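For illustration, the generated cast without and with the alignment template argument might look like this (tensor names and sizes are hypothetical; the third Array template parameter is assumed to be the alignment factor, as in the Array<float, 4, 4> declarations quoted elsewhere in this thread):

    // Without an explicit alignment factor, Array<float, 4> defaults to 1,
    // i.e. no guarantee beyond alignof(float) as far as the compiler knows.
    auto& T29 = *reinterpret_cast<Array<float, 4>*>(&T32);

    // Passing the alias tensor's vectorization width as the alignment factor
    // also hands the 16-byte alignment guarantee to the compiler.
    auto& T29_aligned = *reinterpret_cast<Array<float, 4, 4>*>(&T32);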

Collaborator

Yeah, it's probably not a bad idea to add it: #3084

@@ -299,7 +299,8 @@ std::unique_ptr<caching::VectorizedTensorInfo> getVectorizedTensorValidationInfo

  auto vector_dim = vector_info.vectorized_loop_id;
  const auto is_aligned =
-     vector_dim->getParallelType() == ParallelType::Vectorize;
+     vector_dim->getParallelType() == ParallelType::Vectorize ||
+     vector_dim->getParallelType() == ParallelType::Group;
Collaborator

Why is this change necessary?

Collaborator Author

Because the grouped reduction tv is added to vectorized_set_info; if it is not aligned, it must be a fusion input or output.
If we don't directly use VectorizeValidator::validate(tv), it won't be added to vectorized_set_info and this change is not required.

@naoyam (Collaborator) commented Oct 1, 2024

I'm not sure why we also need to change the executor as well as the vectorization validator. I thought what we are missing is using aligned arrays inside some of the device functions and they are just local temporary arrays, so those are not something we would need to validate, right? What are we validating then?

csrc/codegen.cpp Outdated
@@ -2972,6 +2960,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
aligned_array_of_regs_.insert(tv);
}
}
// tv is aligned if alias is aligned
Collaborator

Why is this? I understand this is just a naming difference, but whether the original allocation is aligned or not shouldn't matter for the aliasing tensor, right? For example, this tv can be a tensor with no alignment requirement, right?

Collaborator Author

It was added to fix the test failure in CombinedSchedulerTest.LayerNormBackward/dtype_float_batch_216_hidden_65536, where we have

Array<float, 4, 4> T32;
auto& T29 = T32;

The compiler treats T29 as an aligned array instead of a regular array, so when passing T29 to a runtime function, we should use T29.array instead of T29.
So if the original allocation is aligned, its aliasing tv should also be aligned due to auto type deduction.
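A short sketch of the distinction (the runtime-function signature below is hypothetical, just to illustrate the .array member access):

    Array<float, 4, 4> T32;
    auto& T29 = T32;  // deduced as Array<float, 4, 4>&, not float (&)[4]

    // Hypothetical runtime function that expects a raw register array:
    // __device__ void blockReduceSketch(float* in);
    //
    // blockReduceSketch(T29.array);  // ok: pass the inner array member
    // blockReduceSketch(T29);        // error: Array<...> is not a float*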

Collaborator Author

On the other hand, T29 doesn't have to be aligned. We can use a dynamic cast to remove this, but then we need to change the code of the alias allocation.

@liqiangxl (Collaborator Author)

> I'm not sure why we also need to change the executor as well as the vectorization validator. I thought what we are missing is using aligned arrays inside some of the device functions and they are just local temporary arrays, so those are not something we would need to validate, right? What are we validating then?

You are right. There is no need to validate. I was using VectorizeValidator::validate(tv) because this function not only validates vectorization, it also collects the vectorization info and stores it in GpuLower::current()->vectorizedAccesses(). We need this info to correctly define the aligned array of registers in codegen.
I'll revise to directly add this info instead of reusing the overkill function VectorizeValidator::validate(tv).

@liqiangxl (Collaborator Author)

!build

@liqiangxl (Collaborator Author)

!build

@@ -598,7 +598,25 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) {
"Only allow misaligned vectorization between global and local memory.");
has_misaligned_vectorize_dim = true;
}

// ParallelType::Group is used for both reduction & normalization.
Collaborator

Is this really necessary? Shared memory is always aligned. If the producer is in global memory, then that should be a temporary work buffer, so it should always be a contiguous, aligned buffer.

@liqiangxl (Collaborator Author) Oct 2, 2024

I mean between registers and shared memory, so we need to ensure the registers are aligned when doing vectorized reads/writes. Let me create an example.

Collaborator Author

In this case, T13 = GroupedReductionOf(T14); T14 (registers) should be aligned because the runtime function uses a vectorized copy from T14 (registers) to shared memory.

T13_l_float[ iblockIdx.x68{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 1) )}, ithreadIdx.x67{blockDim.x}, iUS69{1}, iG65{4}, rthreadIdx.y62{blockDim.y} ] ca_pos( 3 ) produce_pos( 2 )
   = reduction( T14_l_float[ iblockIdx.x54{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 1) )}, ithreadIdx.x53{blockDim.x}, rS60{( ceilDiv(( ceilDiv(( ceilDiv(i1, blockDim.y) ), 2) ), 1) )}rf, iUS55{1}, iS51{4}, ithreadIdx.y57{blockDim.y}rf, rUS61{1}rf, rS59{2}rf ] ca_pos( 2 ) produce_pos( 8 ), op = add, initial value = float(0), allreduce = false )
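Roughly what that vectorized transfer inside the runtime function amounts to (a simplified sketch with a hypothetical function name; the actual blockIterGroupedYdimReduce does more than this):

    // The grouped producer values live in registers; the runtime function
    // writes all 4 grouped elements to shared memory in one 16-byte
    // transaction, which requires the register array to be 16-byte aligned.
    __device__ void copyGroupToSmemSketch(float* smem, const float* regs) {
      *reinterpret_cast<float4*>(&smem[threadIdx.x * 4]) =
          *reinterpret_cast<const float4*>(regs);
    }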

Collaborator

I think I got it. The comment seems wrong, though.

> since they are register arrays defined in runtime function

This producer_tv is not defined in the runtime functions, right?

Collaborator Author

That's right. producer_tv is not defined in the runtime functions; it is just passed to the runtime function.

Collaborator Author

The comment is revised as:

      // ParallelType::Group is used for both reduction and normalization.
      // In grouped outer reduction, the runtime function uses vectorized data
      // transfers between registers and shared memory. The producer tensor is
      // stored in registers and loaded into shared memory in a vectorized
      // manner, so we add it to the vectorizedAccesses map to ensure register
      // alignment.

jacobhinkle added a commit that referenced this pull request Oct 2, 2024
See #2934 (comment)

PR #665 allowed us to re-use allocations that have different dtypes. We
already check that our aliased tensors do not have vectorized accesses
larger than those of the original tensors. However, when we have
different dtypes we `reinterpret_cast` it to a different `Array` type.
Previously we did not specify any alignment in that type's template
args, meaning it assumed an alignment of size 1. Since the actual
addresses will all still be aligned, this does not cause misaligned
accesses at runtime. This PR sets the template arg for alignment to be
that of the vectorized access width for the alias tensor, so that the
compiler could hypothetically do some optimizations knowing the address
is aligned.
@naoyam (Collaborator) left a comment

LGTM. Thanks for the fix.

@liqiangxl (Collaborator Author)

!build

@liqiangxl merged commit 94d4b70 into main Oct 3, 2024
11 of 12 checks passed
@liqiangxl deleted the llu/aligned_reg_array branch October 3, 2024 12:26
Successfully merging this pull request may close these issues.

FusionReductionWithTrivialReduction_CUDA fails with compute-sanitizer