Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TMA support for circular buffering pass #2833

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open

Add TMA support for circular buffering pass #2833

wants to merge 16 commits into from

Conversation

rdspring1
Copy link
Collaborator

@rdspring1 rdspring1 commented Aug 22, 2024

Summary

This PR adds support for TMA circular buffering. It is stacked on #2824 and #2825.
Tracking branch: #2773

Description

  • In the circular buffer pass, clone operations to create the pre-prologue, prologue, main, epilogue, and post-epilogue for-loops.
  • Pre-Prologue allocates share memory and initializes mbarriers.
  • Prologue copies only the load operations.
  • Main loop copies the load and computation operations and adds arrive_expected_tx for next stage and mbarrier_wait for current stage.
  • Epilogue copies only the computation operations and adds mbarrier_wait for remaining stages in the pipeline.
  • Post-Epilogue invalidated mbarriers.

Lowering Details

Description of changes in lowering passes.

  • Prologue, Main, and Epilogue loops are created by TmaCircularBufferLoopCloner which is a child class of CircularBufferLoopCloner.
  • PrePrologue and PostEpilogue loops are created by createCpAsyncBulkFixtures.
  • The cuTensorMapEncodeTiled restricts the size of each box dimension to be <= 256. You need to launch multiple load operations to load larger tiles.
  • We only allocate mbarriers for each stage, so the expected_transaction bytes is multiplied by the number of TMA loads per stage.
  • The for-loop cloner must account for the nested for-loop structure used to launch multiple TMA loads before adding the mbarrier_wait for the stage.

Loop Structure

Description of for-loop structure for circular buffering.

Overview Circular Buffer Structure:

Pre-prologue loop:

  • Allocate shared memory for mbarriers and mbarrier tokens
  • Initialize mbarrier for all stages

Prologue loop:

  • if selected_thread:
    • Issue cp async bulks for all but last stage

Main loop:

  • if selected_thread:
    • Issue next cp async bulk for available stage
  • All threads wait until tma operation arrives
  • Copy body without
    • shared memory allocations
    • mbarrier_init exprs
    • mbarrier_inval exprs

Epilogue loop:

  • All threads wait until tma operation arrives
  • Copy body without
    • shared memory allocations
    • issuing cp async bulk operations
    • mbarrier_init exprs
    • mbarrier_inval exprs

Post-epilogue loop:

  • if selected_thread:
  • Invalidated mbarrier for all stages
Detailed Pseudo-Code:
constexpr int64_t warp_size = 32;
bool first_warp = threadIdx.x < warp_size && threadIdx.y == 0 && threadIdx.z == 0;

Pre-Prologue loop:

__shared__ __mbarrier_t barriers[num_stages];
__shared__ __mbarrier_token_t tokens[num_stages];
if (first_warp && hopper::electSync()) {
  for (int64_t loop_index : irange(stages)) {
    mbarrier_init(mbarrier[loop_index], number_of_arrival_threads);
  }
}

Prologue loop:

for (int64_t loop_index : irange(stages-1)) {
  if (first_warp && hopper::electSync()) {
    tokens[loop_index] = mbarrier::arriveExpectTx(mbarrier[loop_index]);
    cpAsyncBulk(mbarriers[loop_index], ...);
  } else {
    token[load_stage] = mbarrier::arrive(mbarrier[load_stage]);
  }
}

Main loop:

for (int64_t loop_index : irange(N-(stages-1))) {
  current_stage = loop_index % stage_depth
  load_stage = (loop_index + (stage_depth - 1)) % stage_depth)
  if (first_warp && hopper::electSync()) {
    token[load_stage] =
      mbarrier::arriveExpectTx(mbarrier[load_stage], expected_transaction_size);
    cpAsyncBulk(mbarrier[load_stage], ...);
  } else {
    token[load_stage] = mbarrier::arrive(mbarrier[load_stage]);
  }
  mbarrier::wait(token[current_stage]);

  // Clone remaining operations
}

Epilogue loop:

for (int64_t loop_index : irange(N-(stages-1), N)) {
  current_stage = loop_index % stage_depth
  mbarrier::wait(token[current_stage]);

  // Clone remaining operations
}

Post-Epilogue loop:

if (first_warp && hopper::electSync()) {
  for (int64_t loop_index : irange(stages)) {
    mbarrier_inval(mbarrier[loop_index]);
  }
}

Testing Setup

  • 2 to 4 pipeline stages.
  • (128, 500, 1024) outer dimension.
  • (128, 1024) inner dimension.
  1. Single Dim including Unroll and Unswitch parallelizations.
  2. Multiple Dim
  3. Pointwise
  4. Reduction
  5. InnerPersistent
  6. Matmul

@rdspring1 rdspring1 changed the title Add TMA support for circular buffering pass and testing Add TMA support for circular buffering pass Aug 22, 2024
@csarofeen
Copy link
Collaborator

Awesome, detailed PR description. Thank you.

* Add support for Hopper::electSync
* Create ElectSync PredicateType
* Make mbarrier synchronous
  * mbarrier waits for all threads in CTA
  * All threads issues arriveExpectTx to get mbarrier_token
Copy link
Collaborator

@jacobhinkle jacobhinkle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some minor comments from a first pass. I haven't looked at tests yet.

csrc/device_lower/pass/allocation.cpp Outdated Show resolved Hide resolved
csrc/device_lower/pass/circular_buffer.cpp Outdated Show resolved Hide resolved
// mbarrier_inval(mbarrier[loop_index]);
// }
// }
//
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice improvement to the comment at line 34.

csrc/device_lower/pass/circular_buffer.cpp Outdated Show resolved Hide resolved
csrc/device_lower/pass/circular_buffer.cpp Outdated Show resolved Hide resolved
csrc/device_lower/pass/circular_buffer.cpp Outdated Show resolved Hide resolved
csrc/executor.cpp Outdated Show resolved Hide resolved
@@ -209,6 +209,27 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
// here.
return IrBuilder::create<Val>(true, DataType::Bool);
}
case PredicateType::ElectSync: {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a separate PredicateType::ElectSync predicate type? Should we just use whatever the original predicate type it has, and if the conditional happen to be tidx == 0 && tidy == 0 && tidz == 0, we convert it to the elec_sync() conditional? What do you think? @naoyam

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I separated the ElectSync PredicateType changes into #2923.

@rdspring1
Copy link
Collaborator Author

!build

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants