Skip to content

Commit

Permalink
[TritonGPU] Loop scheduling: support latency-annotated `scf.if` ops that yield async tokens, and skip scheduling an `scf.if` whose forward slice contains already-scheduled ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Mogball committed Feb 5, 2025
1 parent 011a99d commit 144771b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 66 deletions.
10 changes: 6 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,14 @@ CoarseSchedule scheduleKeyOps(scf::ForOp forOp,
auto ifOp = dyn_cast<scf::IfOp>(op);
if (!ifOp)
continue;
// If the `scf.if` op itself is a latency op, skip it.
if (opLatency.contains(ifOp))
continue;
// Ensure this does not create scheduling conflicts by ensuring the forward
// slice of the `scf.if` does not contain ops that are already scheduled.
// slice of the `scf.if` does not contain ops that are already scheduled, as
// this will cause the `scf.if` to be scheduled before its dependents.
SetVector<Operation *> slice;
ForwardSliceOptions opts;
opts.inclusive = true;
getForwardSlice(ifOp, &slice, opts);
getForwardSlice(ifOp, &slice);
if (llvm::any_of(slice, [&](Operation *op) { return opToStage.count(op); }))
continue;
schedule.insert(ifOp, stage, epilogue);
Expand Down
79 changes: 50 additions & 29 deletions manual.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,73 @@

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

!tensor_ptr = tensor<128x128x!tt.ptr<f32>, #blocked>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) -> tensor<128x128xf32, #blocked1> {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) {
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%cst_1 = arith.constant dense<1> : tensor<128x128xi32, #blocked>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32

%buf = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
%buf1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>

%0:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32, %arg7 = %cst, %arg8 = %arg3, %arg9 = %arg4) -> (i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>) : i32 {
%1 = arith.cmpi eq, %arg6, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : i32
%2 = arith.addi %arg6, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 1 : i32} : i32
%0:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32,
%arg9 = %arg4,
%arg10 = %arg4, %arg11 = %arg4) -> (i32, !tensor_ptr, !tensor_ptr, !tensor_ptr) : i32 {
%1 = arith.cmpi eq, %arg6, %c0_i32 : i32
%2 = arith.addi %arg6, %c1_i32 : i32

scf.if %1 {
%tok = ttg.async_copy_global_to_local %arg8, %buf : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
//"something"(%1) {tt_latency = 2 : i32} : (i1) -> ()
%tok0 = scf.if %1 -> !ttg.async.token {
%tok = ttg.async_copy_global_to_local %arg10, %buf : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
%tok2 = ttg.async_commit_group %tok
scf.yield
scf.yield %tok2 : !ttg.async.token
} else {
scf.yield
} {loop.cluster = 0 : i32, loop.stage = 1 : i32}
%undef = ub.poison : !ttg.async.token
scf.yield %undef : !ttg.async.token
} {tt_latency = 2 : i32}

//%3 = "something"(%1) : (i1) -> tensor<128x128xf32, #blocked>
%3 = scf.if %1 -> (tensor<128x128xf32, #blocked>) {
ttg.async_wait {num = 0 : i32}
%11 = ttg.local_load %buf : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
%tt = ttg.async_wait %tok0 {num = 0 : i32}
%11 = ttg.local_load %buf token %tt : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
scf.yield %11 : tensor<128x128xf32, #blocked>
} else {
scf.yield %cst_0 : tensor<128x128xf32, #blocked>
} {loop.cluster = 0 : i32, loop.stage = 2 : i32}

%4 = tt.load %arg9 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>
%5 = ttg.convert_layout %3 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
%6 = ttg.convert_layout %4 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>
%7 = tt.dot %5, %6, %arg7, inputPrecision = tf32 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x128xf32, #blocked1>
%8 = tt.addptr %arg8, %cst_1 {loop.cluster = 1 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%9 = tt.addptr %arg9, %cst_1 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%10 = arith.cmpi eq, %2, %arg1 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32
}

%9 = tt.addptr %arg9, %cst_1: tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%15 = tt.addptr %arg10, %cst_1: tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%16 = tt.addptr %arg11, %cst_1: tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>


%4 = tt.load %arg9 {tt_latency = 2 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>

%5 = arith.addf %4, %3 : tensor<128x128xf32, #blocked>


%10 = arith.cmpi eq, %arg6, %arg1 : i32
%tok1 = scf.if %10 -> !ttg.async.token {
%tok = ttg.async_copy_global_to_local %arg11, %buf1 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
%tok2 = ttg.async_commit_group %tok
scf.yield %tok2 : !ttg.async.token
} else {
%undef = ub.poison : !ttg.async.token
scf.yield %undef : !ttg.async.token
} {tt_latency = 2 : i32}

scf.if %10 {
%11 = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
%12 = ttg.convert_layout %7 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
%13 = arith.addf %11, %12 : tensor<128x128xf32, #blocked>
tt.store %arg9, %13 : tensor<128x128x!tt.ptr<f32>, #blocked>
} {loop.cluster = 5 : i32, loop.stage = 2 : i32}
scf.yield %2, %7, %8, %9 : i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>
}
tt.return %0#1 : tensor<128x128xf32, #blocked1>
%kk = ttg.async_wait %tok1 {num = 0 : i32}
%11 = ttg.local_load %buf1 token %kk : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
%x = arith.addf %11, %5 : tensor<128x128xf32, #blocked>
tt.store %arg4, %x : tensor<128x128x!tt.ptr<f32>, #blocked>
}
scf.yield %2, %9, %15, %16 : i32, tensor<128x128x!tt.ptr<f32>, #blocked>, !tensor_ptr, !tensor_ptr
} {tt.num_stages = 3 : i32}
tt.return
}
}

75 changes: 42 additions & 33 deletions test.mlir
Original file line number Diff line number Diff line change
@@ -1,46 +1,55 @@
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) -> tensor<128x128xf32, #blocked1> {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%cst_1 = arith.constant dense<1> : tensor<128x128xi32, #blocked>
tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%cst_0 = arith.constant dense<1> : tensor<128x128xi32, #blocked>
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
%1:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32, %arg7 = %cst, %arg8 = %arg3, %arg9 = %arg4) -> (i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>) : i32 {
%2 = arith.cmpi eq, %arg6, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : i32
%3 = arith.addi %arg6, %c1_i32 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32
scf.if %2 {
%12 = ttg.async_copy_global_to_local %arg8, %0 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
%13 = ttg.async_commit_group %12
%1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable>
%2:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32, %arg7 = %arg4, %arg8 = %arg4, %arg9 = %arg4) -> (i32, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>) : i32 {
%3 = arith.cmpi eq, %arg6, %c0_i32 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32
%4 = arith.addi %arg6, %c1_i32 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : i32
%5 = scf.if %3 -> (!ttg.async.token) {
%14 = ttg.async_copy_global_to_local %arg8, %0 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
%15 = ttg.async_commit_group %14
scf.yield %15 : !ttg.async.token
} else {
} {loop.cluster = 6 : i32, loop.stage = 2 : i32}
%4 = scf.if %2 -> (tensor<128x128xf32, #blocked>) {
%12 = ttg.async_wait {num = 0 : i32}
%13 = ttg.local_load %0 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
scf.yield %13 : tensor<128x128xf32, #blocked>
%14 = ub.poison : !ttg.async.token
scf.yield %14 : !ttg.async.token
} {loop.cluster = 4 : i32, loop.stage = 0 : i32, tt_latency = 2 : i32}
%6 = scf.if %3 -> (tensor<128x128xf32, #blocked>) {
%14 = ttg.async_wait %5 {num = 0 : i32}
%15 = ttg.local_load %0 token %14 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
scf.yield %15 : tensor<128x128xf32, #blocked>
} else {
scf.yield %cst_0 : tensor<128x128xf32, #blocked>
} {loop.cluster = 0 : i32, loop.stage = 2 : i32}
%5 = tt.load %arg9 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>
%6 = ttg.convert_layout %4 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
%7 = ttg.convert_layout %5 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>
%8 = tt.dot %6, %7, %arg7, inputPrecision = tf32 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x128xf32, #blocked1>
%9 = tt.addptr %arg8, %cst_1 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%10 = tt.addptr %arg9, %cst_1 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%11 = arith.cmpi eq, %3, %arg1 {loop.cluster = 5 : i32, loop.stage = 2 : i32} : i32
scf.if %11 {
%12 = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
%13 = ttg.convert_layout %8 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
%14 = arith.addf %12, %13 : tensor<128x128xf32, #blocked>
tt.store %arg9, %14 : tensor<128x128x!tt.ptr<f32>, #blocked>
scf.yield %cst : tensor<128x128xf32, #blocked>
} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
%7 = tt.addptr %arg7, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%8 = tt.addptr %arg8, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%9 = tt.addptr %arg9, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
%10 = tt.load %arg7 {loop.cluster = 4 : i32, loop.stage = 0 : i32, tt_latency = 2 : i32} : tensor<128x128x!tt.ptr<f32>, #blocked>
%11 = arith.addf %10, %6 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked>
%12 = arith.cmpi eq, %arg6, %arg1 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32
%13 = scf.if %12 -> (!ttg.async.token) {
%14 = ttg.async_copy_global_to_local %arg9, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
%15 = ttg.async_commit_group %14
scf.yield %15 : !ttg.async.token
} else {
%14 = ub.poison : !ttg.async.token
scf.yield %14 : !ttg.async.token
} {loop.cluster = 4 : i32, loop.stage = 0 : i32, tt_latency = 2 : i32}
scf.if %12 {
%14 = ttg.async_wait %13 {num = 0 : i32}
%15 = ttg.local_load %1 token %14 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
%16 = arith.addf %15, %11 : tensor<128x128xf32, #blocked>
tt.store %arg4, %16 : tensor<128x128x!tt.ptr<f32>, #blocked>
} {loop.cluster = 5 : i32, loop.stage = 2 : i32}
scf.yield %3, %8, %9, %10 : i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>
}
tt.return %1#1 : tensor<128x128xf32, #blocked1>
scf.yield %4, %7, %8, %9 : i32, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>
} {tt.num_stages = 3 : i32}
tt.return
}
}

0 comments on commit 144771b

Please sign in to comment.