From 00bbd3614629e3faa930e00fa3bf9164999d0db9 Mon Sep 17 00:00:00 2001
From: Mogball
Date: Thu, 6 Feb 2025 13:07:40 -0800
Subject: [PATCH] [Pipeliner] Pipeline loads nested inside scf.if ops; drop
 scratch .mlir files

---
 case.mlir                                     |  44 --
 .../Transforms/Pipeliner/AssignLatencies.cpp  |  25 +-
 .../Pipeliner/MatmulLoopPipeline.cpp          |  57 +-
 manual.mlir                                   |  86 ---
 real.mlir                                     | 596 ------------------
 test.mlir                                     |  44 --
 test1.mlir                                    | 219 -------
 7 files changed, 75 insertions(+), 996 deletions(-)
 delete mode 100644 case.mlir
 delete mode 100644 manual.mlir
 delete mode 100644 real.mlir
 delete mode 100644 test.mlir
 delete mode 100644 test1.mlir

diff --git a/case.mlir b/case.mlir
deleted file mode 100644
index 346570089a59..000000000000
--- a/case.mlir
+++ /dev/null
@@ -1,44 +0,0 @@
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-
-tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) -> tensor<128x128xf32, #blocked1> {
-  %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked1>
-  %cst1 = arith.constant dense<0.0> : tensor<128x128xf32, #blocked>
-  %cst_0 = arith.constant dense<1> : tensor<128x128xi32, #blocked>
-  %c0 = arith.constant 0 : i32
-  %c1 = arith.constant 1 : i32
-
-  %0:4 = scf.for %arg6 = %arg0 to %arg1 step %arg2 iter_args(%k = %c0, %arg7 = %cst, %arg8 = %arg3, %arg9 = %arg4) -> (i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>) : i32 {
-    %k_cond = arith.cmpi eq, %k, %c0 : i32
-    %next_k = arith.addi %k, %c1 : i32
-    %3 = scf.if %k_cond -> tensor<128x128xf32, #blocked> {
-      %res = tt.load %arg8 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      scf.yield %res : tensor<128x128xf32, #blocked>
-    } else {
-      scf.yield %cst1 : tensor<128x128xf32, #blocked>
-    }
-
-    %4 = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
-    %5 = ttg.convert_layout %3 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
-    %6 = ttg.convert_layout %4 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>>
-    %7 = tt.dot %5, %6, %arg7 {inputPrecision = 0 : i32} : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x128xf32, #blocked1>
-    %8 = tt.addptr %arg8, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-    %9 = tt.addptr %arg9, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-
-    %k_cond_2 = arith.cmpi eq, %next_k, %arg1 : i32
-    scf.if %k_cond_2 {
-      %res = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      %rhs = ttg.convert_layout %7 : tensor<128x128xf32, #blocked1> -> tensor<128x128xf32, #blocked>
-      %sum = arith.addf %res, %rhs : tensor<128x128xf32, #blocked>
-      tt.store %arg9, %sum : tensor<128x128x!tt.ptr<f32>, #blocked>
-    }
-
-
-    scf.yield %next_k, %7, %8, %9 : i32, tensor<128x128xf32, #blocked1>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>
-  }
-  tt.return %0#1 : tensor<128x128xf32, #blocked1>
-}
-
-}
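The AssignLatencies.cpp change below tags every scf.if whose then- or
else-block contains a load with a ttg.conditional_load attribute, pins such
loads to a latency of 1, and subtracts the stages they claim from the budget
shared by the remaining loads. A standalone sketch of the new budget
arithmetic, with assumed inputs (tt.num_stages = 4, one conditional load, no
indirect loads) rather than the pass's real data:

    #include <cstdio>

    int main() {
      int numStages = 4;           // tt.num_stages on the loop
      int usedStages = 1;          // stages claimed by conditional loads
      int maxIndirectionLevel = 0; // remaining loads feed their users directly
      unsigned loadLatency =
          (numStages - 1 - usedStages) / (maxIndirectionLevel + 1);
      std::printf("latency for remaining loads: %u\n", loadLatency); // 2
      return 0;
    }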
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
index 97fadf5efb8e..5271667dcede 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
@@ -206,7 +206,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
 
   // If the loop has numStages attribute, also consider pipelining other loads
   // that are not directly used by dot ops.
-  if (pipelineWithoutDot && !seenDot) {
+  if (pipelineWithoutDot /*&& !seenDot*/) {
     for (Operation &op : forOp.getBody()->without_terminator()) {
       if (!isa(op))
@@ -263,6 +263,16 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
 
   DenseMap<Operation *, int> opLatency;
   for (auto forOp : loops) {
+    for (auto ifOp : forOp.getBody()->getOps<scf::IfOp>()) {
+      auto isLoad = [&](Operation &op) {
+        return isa(op);
+      };
+      if (llvm::any_of(*ifOp.thenBlock(), isLoad) ||
+          (ifOp.elseBlock() && llvm::any_of(*ifOp.elseBlock(), isLoad)))
+        ifOp->setAttr("ttg.conditional_load", UnitAttr::get(ifOp.getContext()));
+    }
+
     if (hasLatenciesAssigned(forOp)) {
       assignUserProvidedLatencies(forOp, opLatency);
       continue;
@@ -289,10 +299,21 @@ DenseMap<Operation *, int> assignLatencies(ModuleOp moduleOp,
       ++iter;
     }
 
+    int usedStages = 0;
+    for (Operation *loadOp :
+         llvm::to_vector(llvm::make_first_range(loadOpToIndLevel))) {
+      if (isa(loadOp)) {
+        opLatency[loadOp] = 1;
+        usedStages = std::max(usedStages, loadOpToIndLevel[loadOp]);
+        loadOpToIndLevel.erase(loadOp);
+      }
+    }
+
     // Calculate the stage distance between applicable loads.
     auto vals = llvm::make_second_range(loadOpToIndLevel);
     int maxIndirectionLevel = vals.empty() ? 0 : *llvm::max_element(vals);
-    unsigned loadLatency = (numStages - 1) / (maxIndirectionLevel + 1);
+    unsigned loadLatency =
+        (numStages - 1 - usedStages) / (maxIndirectionLevel + 1);
 
     for (auto [loadOp, dist] : loadOpToIndLevel) {
       opLatency[loadOp] = loadLatency;
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index 31e9b47975fa..1c25b6dc378d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -53,7 +53,7 @@ struct LoadInfo {
            << "  sharedEncoding: " << sharedEncoding << "\n"
            << "  blockedEncoding: " << blockedEncoding << "\n"
            << "  isMMAv3Shared: " << isMMAv3Shared << "\n"
-           << "  isMMAv3Registers: " << isMMAv3Registers << "\n"
+           << "  isMMAv5Scale: " << isMMAv5Scale << "\n"
            << "  distToUse: " << distToUse << "\n"
            << "  usedByDot: " << usedByDot << "\n";
@@ -112,8 +112,14 @@ static bool sameStageCluster(Operation *op1, Operation *op2) {
 // Return user of a loadOp with the lowest stage, if two users have the
 // same stage, return the user with lower cluster.
 static Operation *getFirstUseOfPipelinedLoad(Operation *loadOp) {
+  if (auto condOp = dyn_cast(loadOp->getParentOp()))
+    loadOp = condOp;
+
   Operation *firstUser = nullptr;
   for (Operation *user : loadOp->getUsers()) {
+    if (isa(user))
+      continue;
+
     // Climb up to the containing op in the same block as the load.
     while (user->getBlock() != loadOp->getBlock())
       user = user->getParentOp();
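In the hunks that follow, getFirstUseOfPipelinedLoad treats a load nested
inside a tagged scf.if as if the scf.if itself were the load, so the
first-use search runs against the loop body rather than the if body. The
climb from each user up to the load's own block is the usual parent-op walk;
a minimal sketch (climbToBlock is a hypothetical helper, the pass does this
inline):

    #include "mlir/IR/Operation.h"

    // Hoist `user` to the ancestor op that lives directly in `block`.
    static mlir::Operation *climbToBlock(mlir::Operation *user,
                                         mlir::Block *block) {
      while (user->getBlock() != block)
        user = user->getParentOp();
      return user;
    }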
@@ -522,6 +528,8 @@ assignMemoryLayouts(scf::ForOp &forOp,
     // as a pipelined load.
     auto [sLoad, _cLoad] = tt::getStageCluster(&op);
     Operation *firstUse = getFirstUseOfPipelinedLoad(&op);
+    if (!firstUse)
+      continue;
     LDBG("first use for load " << op);
     LDBG("  - use: " << *firstUse);
     auto firstUseStageCluster = tt::maybeGetStageCluster(firstUse);
@@ -535,6 +543,9 @@ assignMemoryLayouts(scf::ForOp &forOp,
     if (auto condOp = dyn_cast(op)) {
       condOp->setAttr("ttg.pipelined_load", UnitAttr::get(op.getContext()));
       for (Operation *op : condOp.getLoads()) {
+        if (!isa<RankedTensorType>(op->getResultTypes().front()))
+          continue;
+
         loadsToPipeline.insert(op);
         LoadInfo &loadInfo = loadToInfo[op];
         loadInfo.distToUse = distToUse;
@@ -998,6 +1009,23 @@ static SmallVector splitIntoClusters(Block *block) {
   return result;
 }
 
+// To model an "undef" value, i.e. a value that is known to never be read on
+// live code paths, create a zero-valued constant where possible, otherwise use
+// a poison value. PTXAS appears to generate better code with zeros compared to
+// poison values.
+static Value createPoisonOrZero(ImplicitLocOpBuilder &b, Type type) {
+  Type elTy = getElementTypeOrSelf(type);
+  if (!elTy.isIntOrIndexOrFloat() ||
+      (!isa<RankedTensorType>(type) && type != elTy))
+    return b.create<ub::PoisonOp>(type);
+
+  TypedAttr attr = isa<FloatType>(elTy) ? TypedAttr(b.getFloatAttr(elTy, 0))
+                                        : b.getIntegerAttr(elTy, 0);
+  if (auto tensor = dyn_cast<RankedTensorType>(type))
+    attr = SplatElementsAttr::get(tensor, attr);
+  return b.create<arith::ConstantOp>(attr);
+}
+
 static void splitIntoClusters(scf::IfOp ifOp) {
   // First partition the regions into clusters.
   SmallVector thenClusters = splitIntoClusters(ifOp.thenBlock());
@@ -1026,7 +1054,7 @@ static void splitIntoClusters(scf::IfOp ifOp) {
     b.createBlock(&otherRegion);
     SmallVector<Value> undefs;
     for (Type type : cluster.getOutputTypes())
-      undefs.push_back(b.create<ub::PoisonOp>(type));
+      undefs.push_back(createPoisonOrZero(b, type));
     b.create<scf::YieldOp>(undefs);
   };
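The hunk that follows replaces the emptied-out scf.if with one arith.select
per result and schedules each select at the later of its two producers'
(stage, cluster) pairs. std::pair orders lexicographically, so max_element
picks the larger stage first and breaks ties on cluster; a self-contained
sketch with made-up pairs:

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      // Assumed producers: then-value at (1, 2), else-value at (2, 0).
      std::vector<std::pair<int, int>> stageClusters = {{1, 2}, {2, 0}};
      auto [stage, cluster] =
          *std::max_element(stageClusters.begin(), stageClusters.end());
      std::printf("select at stage %d, cluster %d\n", stage, cluster); // 2, 0
      return 0;
    }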
@@ -1041,9 +1069,24 @@ static void splitIntoClusters(scf::IfOp ifOp) {
         isThen ? clusterIf.getElseRegion() : clusterIf.getThenRegion());
   }
 
-  // Set the leftover select to the stage and cluster of the first use.
-  auto [stage, cluster] = tt::getStageCluster(getFirstUseOfPipelinedLoad(ifOp));
-  tt::setStageCluster(ifOp, stage, cluster);
+  // Break up the final if.
+  for (auto [trueVal, falseVal, result] :
+       llvm::zip(ifOp.thenYield().getOperands(), ifOp.elseYield().getOperands(),
+                 ifOp.getResults())) {
+    SmallVector<std::pair<int, int>> stageClusters;
+    if (Operation *op = trueVal.getDefiningOp())
+      stageClusters.push_back(tt::getStageCluster(op));
+    if (Operation *op = falseVal.getDefiningOp())
+      stageClusters.push_back(tt::getStageCluster(op));
+    auto [stage, cluster] = stageClusters.empty()
+                                ? tt::getStageCluster(ifOp)
+                                : *llvm::max_element(stageClusters);
+    auto select =
+        b.create<arith::SelectOp>(ifOp.getCondition(), trueVal, falseVal);
+    tt::setStageCluster(select, stage, cluster);
+    result.replaceAllUsesWith(select);
+  }
+  ifOp.erase();
 }
 
 static void processConditionalLoads(scf::ForOp forOp, int numStages) {
@@ -1071,6 +1114,9 @@ static void processConditionalLoads(scf::ForOp forOp, int numStages) {
     auto &firstUseCluster = clusters[clusterForFirstUse];
 
     for (Operation *loadOp : condOp.getLoads()) {
+      if (!isa<RankedTensorType>(loadOp->getResultTypes().front()))
+        continue;
+
       nestedSchedule.insert(loadOp, stage, loadCluster);
       nestedSchedule.insertDepsOfOp(loadOp, stage, loadCluster,
                                     /*includeArg=*/false);
@@ -1284,6 +1330,7 @@ createAsyncOps(scf::ForOp &forOp,
     if (condOp->removeAttr("ttg.pipelined_load"))
       splitIntoClusters(condOp);
   }
+  assert(succeeded(mlir::verify(forOp)) && "splitting produced invalid IR");
 
   tt::CoarseSchedule coarseSchedule(numStages);
   coarseSchedule.deSerialize(forOp);
diff --git a/manual.mlir b/manual.mlir
deleted file mode 100644
index de6ecbf0299d..000000000000
--- a/manual.mlir
+++ /dev/null
@@ -1,86 +0,0 @@
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
-#smem = #ttg.shared_memory
-
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
-
-tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked>, %arg4: tensor<128x128x!tt.ptr<f32>, #blocked>) {
-  %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
-  %cst_0 = arith.constant dense<1> : tensor<128x128xi32, #blocked>
-  %c0_i32 = arith.constant 0 : i32
-  %c1_i32 = arith.constant 1 : i32
-
-  %2:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32, %arg7 = %arg4, %arg8 = %arg4, %arg9 = %arg4) -> (i32, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>) : i32 {
-    %3 = arith.cmpi eq, %arg6, %c0_i32 : i32
-    %4 = arith.addi %arg6, %c1_i32 : i32
-
-    //%5 = scf.if %3 -> (!ttg.async.token) {
-    //  %14 = ttg.async_copy_global_to_local %arg8, %0 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
-    //  %15 = ttg.async_commit_group %14
-    //  scf.yield %15 : !ttg.async.token
-    //} else {
-    //  %14 = ub.poison : !ttg.async.token
-    //  scf.yield %14 : !ttg.async.token
-    //}
-
-    //%6 = scf.if %3 -> (tensor<128x128xf32, #blocked>) {
-    //  %14 = ttg.async_wait %5 {num = 0 : i32}
-    //  %15 = ttg.local_load %0 token %14 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
-    //  scf.yield %15 : tensor<128x128xf32, #blocked>
-    //} else {
-    //  scf.yield %cst : tensor<128x128xf32, #blocked>
-    //}
-
-    %6:2 = scf.if %3 -> (tensor<128x128xf32, #blocked>, i32) {
-      %backward_dep = tt.addptr %arg8, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-      %15 = tt.load %backward_dep : tensor<128x128x!tt.ptr<f32>, #blocked>
-      %xxxx = tt.load %arg8 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      %fwd_dep = arith.addf %15, %xxxx : tensor<128x128xf32, #blocked>
-      %other = arith.addi %arg5, %arg5 : i32
-      scf.yield %fwd_dep, %other : tensor<128x128xf32, #blocked>, i32
-    } else {
-      %15 = tt.load %arg8 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      %fwd_dep = arith.addf %15, %cst : tensor<128x128xf32, #blocked>
-      %other = arith.addi %arg5, %arg5 : i32
-      scf.yield %fwd_dep, %other : tensor<128x128xf32, #blocked>, i32
-    } {ttg.conditional_load, b}
-
-    %7 = tt.addptr %arg7, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-    %8 = tt.addptr %arg8, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-    %9 = tt.addptr %arg9, %cst_0 : tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128xi32, #blocked>
-
-    %10 = tt.load %arg7 : tensor<128x128x!tt.ptr<f32>, #blocked>
-
-    %11 = arith.addf %10, %6#0 : tensor<128x128xf32, #blocked>
-    %12 = arith.cmpi eq, %arg6, %arg1 : i32
-
-    //%13 = scf.if %12 -> (!ttg.async.token) {
-    //  %14 = ttg.async_copy_global_to_local %arg9, %1 : tensor<128x128x!tt.ptr<f32>, #blocked> -> <128x128xf32, #shared, #smem, mutable>
-    //  %15 = ttg.async_commit_group %14
-    //  scf.yield %15 : !ttg.async.token
-    //} else {
-    //  %14 = ub.poison : !ttg.async.token
-    //  scf.yield %14 : !ttg.async.token
-    //}
-
-    %15 = scf.if %12 -> tensor<128x128xf32, #blocked> {
-      %xx = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      scf.yield %xx : tensor<128x128xf32, #blocked>
-    } else {
-      scf.yield %cst : tensor<128x128xf32, #blocked>
-    } {ttg.conditional_load, a}
-
-    scf.if %12 {
-      //%14 = ttg.async_wait %13 {num = 0 : i32}
-      //%15 = ttg.local_load %1 token %14 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable> -> tensor<128x128xf32, #blocked>
-      //%15 = tt.load %arg9 : tensor<128x128x!tt.ptr<f32>, #blocked>
-      %16 = arith.addf %15, %11 : tensor<128x128xf32, #blocked>
-      tt.store %arg4, %16 : tensor<128x128x!tt.ptr<f32>, #blocked>
-    }
-    scf.yield %4, %7, %8, %9 : i32, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>, tensor<128x128x!tt.ptr<f32>, #blocked>
-
-  } {tt.num_stages = 3 : i32}
-  tt.return
-}
-
-}
diff --git a/real.mlir b/real.mlir
deleted file mode 100644
index f63129ae9b5e..000000000000
--- a/real.mlir
+++ /dev/null
@@ -1,596 +0,0 @@
-
-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#blocked3 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
-#loc = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0)
-#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
-#shared1 = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
-#shared2 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
-#smem = #ttg.shared_memory
-#tmem = #ttng.tensor_memory_encoding
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @_p_matmul_ogs_NNT_bf16xbf16xbf16_128x128x64x1(%arg0: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg1: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg4: !tt.ptr {tt.divisibility = 16 : i32}
loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg5: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg6: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg7: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg8: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg9: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg10: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg11: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg12: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg13: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg14: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg15: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg16: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg17: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg18: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg19: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg20: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg21: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg22: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg23: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg24: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg25: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg26: !tt.ptr loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg27: !tt.ptr loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg28: i32 loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0), %arg29: i32 {tt.divisibility = 16 : i32} loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":48:0)) attributes {noinline = false} { - %c2_i64 = arith.constant 2 : i64 loc(#loc1) - %c3_i32 = arith.constant 3 : i32 loc(#loc1) - %c2_i32 = arith.constant 2 : i32 loc(#loc1) - %c4_i32 = arith.constant 4 : i32 loc(#loc1) - %false = arith.constant 
false loc(#loc1) - %true = arith.constant true loc(#loc1) - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xbf16, #blocked> loc(#loc1) - %c148_i32 = arith.constant 148 : i32 loc(#loc1) - %c1_i32 = arith.constant 1 : i32 loc(#loc1) - %c128_i32 = arith.constant 128 : i32 loc(#loc1) - %c0_i32 = arith.constant 0 : i32 loc(#loc1) - %cst_0 = arith.constant dense<0> : tensor<128xi32, #blocked1> loc(#loc1) - %c-1_i32 = arith.constant -1 : i32 loc(#loc1) - %cst_1 = arith.constant dense<-1> : tensor<128xi32, #blocked1> loc(#loc1) - %c1073741824_i32 = arith.constant 1073741824 : i32 loc(#loc1) - %c65535_i32 = arith.constant 65535 : i32 loc(#loc1) - %c16_i32 = arith.constant 16 : i32 loc(#loc1) - %c64_i32 = arith.constant 64 : i32 loc(#loc1) - %cst_2 = arith.constant dense<1073741825> : tensor<128xi32, #blocked1> loc(#loc1) - %c147_i32 = arith.constant 147 : i32 loc(#loc1) - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128xf32, #blocked1> loc(#loc1) - %c8_i32 = arith.constant 8 : i32 loc(#loc1) - %c63_i32 = arith.constant 63 : i32 loc(#loc1) - %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked2> loc(#loc1) - %0 = tt.load %arg26 : !tt.ptr loc(#loc2) - %1 = arith.subi %arg28, %0 : i32 loc(#loc3) - %2 = tt.get_program_id x : i32 loc(#loc4) - %3 = arith.subi %c147_i32, %2 : i32 loc(#loc5) - %4 = arith.muli %arg28, %arg29 : i32 loc(#loc6) - %5 = arith.muli %arg29, %c8_i32 : i32 loc(#loc109) - %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc9) - %7 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc9) - %8 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #blocked1> loc(#loc9) - %9 = tt.splat %arg21 : i32 -> tensor<128xi32, #blocked1> loc(#loc10) - %10 = tt.splat %arg20 : !tt.ptr -> tensor<128x!tt.ptr, #blocked1> loc(#loc11) - %11 = arith.extsi %arg6 : i32 to i64 loc(#loc12) - %12 = tt.splat %11 : i64 -> tensor<128x1xi64, #blocked> loc(#loc12) - %13 = tt.splat %arg4 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> loc(#loc13) - %14 = tt.splat %arg18 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc14) - scf.for %arg30 = %3 to %4 step %c148_i32 : i32 { - %83 = arith.divsi %arg30, %5 : i32 loc(#loc110) - %84 = arith.muli %83, %c8_i32 : i32 loc(#loc111) - %85 = arith.subi %arg28, %84 : i32 loc(#loc112) - %86 = arith.minsi %85, %c8_i32 : i32 loc(#loc113) - %87 = arith.remsi %arg30, %86 : i32 loc(#loc114) - %88 = arith.addi %84, %87 : i32 loc(#loc115) - %89 = arith.remsi %arg30, %5 : i32 loc(#loc116) - %90 = arith.divsi %89, %86 : i32 loc(#loc117) - %91 = arith.muli %88, %c128_i32 : i32 loc(#loc24) - %92 = tt.splat %91 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc25) - %93 = tt.splat %91 : i32 -> tensor<128xi32, #blocked1> loc(#loc25) - %94 = arith.addi %92, %6 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc25) - %95 = arith.addi %93, %8 : tensor<128xi32, #blocked1> loc(#loc25) - %96 = arith.muli %90, %c128_i32 : i32 loc(#loc26) - %97 = tt.splat %96 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %98 = arith.addi %97, %7 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc27) - %99 = arith.cmpi slt, %95, %9 : tensor<128xi32, #blocked1> loc(#loc10) - %100 = tt.addptr %10, %95 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> loc(#loc11) - %101 = tt.load %100, %99, %cst_0 : 
tensor<128x!tt.ptr, #blocked1> loc(#loc28) - %102 = arith.extsi %94 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc29) - %103 = tt.expand_dims %102 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> loc(#loc30) - %104 = arith.muli %103, %12 : tensor<128x1xi64, #blocked> loc(#loc12) - %105 = tt.addptr %13, %104 : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi64, #blocked> loc(#loc13) - %106 = tt.expand_dims %98 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> loc(#loc31) - %107 = tt.broadcast %105 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> loc(#loc32) - %108 = tt.broadcast %106 : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> loc(#loc32) - %109 = tt.addptr %107, %108 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> loc(#loc32) - %110 = arith.cmpi slt, %98, %14 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc14) - %111 = arith.cmpi eq, %101, %cst_1 : tensor<128xi32, #blocked1> loc(#loc33) - %112 = ttg.convert_layout %111 : tensor<128xi1, #blocked1> -> tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc34) - %113 = tt.expand_dims %112 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked> loc(#loc35) - %114 = tt.expand_dims %110 {axis = 0 : i32} : tensor<128xi1, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi1, #blocked> loc(#loc36) - %115 = tt.broadcast %113 : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc34) - %116 = tt.broadcast %114 : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> loc(#loc34) - %117 = arith.andi %115, %116 : tensor<128x128xi1, #blocked> loc(#loc34) - tt.store %109, %cst, %117 : tensor<128x128x!tt.ptr, #blocked> loc(#loc37) - } loc(#loc15) - %15 = arith.extsi %arg9 : i32 to i64 loc(#loc38) - %16 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr loc(#loc38) - %17 = arith.muli %15, %c2_i64 : i64 loc(#loc38) - tt.experimental_tensormap_create %16, %arg7, [%c64_i32, %c128_i32], [%arg19, %c0_i32], [%17], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc38) - tt.experimental_tensormap_fenceproxy_acquire %16 : !tt.ptr loc(#loc38) - %18 = tt.reinterpret_tensor_descriptor %16 : !tt.ptr to !tt.tensordesc> loc(#loc38) - %19 = arith.extsi %arg12 : i32 to i64 loc(#loc39) - %20 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr loc(#loc39) - %21 = arith.muli %19, %c2_i64 : i64 loc(#loc39) - tt.experimental_tensormap_create %20, %arg10, [%c64_i32, %c128_i32], [%arg19, %arg18], [%21], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc39) - tt.experimental_tensormap_fenceproxy_acquire %20 : !tt.ptr loc(#loc39) - %22 = tt.reinterpret_tensor_descriptor %20 : !tt.ptr to !tt.tensordesc> loc(#loc39) - %23 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 128 : i32} : !tt.ptr loc(#loc40) - %24 = arith.muli %11, %c2_i64 : i64 loc(#loc40) - tt.experimental_tensormap_create %23, %arg4, [%c64_i32, %c1_i32], [%arg18, %c1073741824_i32], [%24], [%c1_i32, %c1_i32] {elem_type = 1 : i32, 
fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc40) - tt.experimental_tensormap_fenceproxy_acquire %23 : !tt.ptr loc(#loc40) - %25 = arith.addi %arg19, %c63_i32 : i32 loc(#loc118) - %26 = arith.divsi %25, %c64_i32 : i32 loc(#loc119) - %27 = arith.subi %arg28, %1 : i32 loc(#loc44) - %28 = arith.muli %27, %arg29 : i32 loc(#loc45) - %29 = arith.divsi %28, %c148_i32 : i32 loc(#loc46) - %30 = arith.remsi %28, %c148_i32 : i32 loc(#loc47) - %31 = arith.cmpi slt, %2, %30 : i32 loc(#loc48) - %32 = scf.if %31 -> (i32) { - %83 = arith.addi %29, %c1_i32 : i32 loc(#loc50) - scf.yield %83 : i32 loc(#loc50) - } else { - scf.yield %29 : i32 loc(#loc1) - } loc(#loc49) - %33 = arith.subi %2, %c148_i32 : i32 loc(#loc51) - %34 = arith.muli %26, %32 : i32 loc(#loc52) - %35 = arith.subi %26, %c1_i32 : i32 loc(#loc53) - %36 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32} : !tt.ptr loc(#loc54) - %37 = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 512 : i32} : !tt.ptr loc(#loc55) - %38 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> loc(#loc56) - %39 = ttg.local_alloc : () -> !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> loc(#loc57) - %40 = ttg.local_alloc : () -> !ttg.memdesc<4xi64, #shared1, #smem, mutable> loc(#loc58) - %41 = ttg.memdesc_subview %40[%c0_i32] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.init_barrier %41, 1 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %42 = ttg.memdesc_subview %40[%c1_i32] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.init_barrier %42, 1 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %43 = ttg.memdesc_subview %40[%c2_i32] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.init_barrier %43, 1 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %44 = ttg.memdesc_subview %40[%c3_i32] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.init_barrier %44, 1 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %45 = arith.cmpi sgt, %34, %c0_i32 : i32 loc(#loc58) - %46 = arith.select %45, %2, %33 : i32 loc(#loc59) - %47 = arith.extui %45 : i1 to i32 loc(#loc59) - %48:7 = scf.if %45 -> (!tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32) { - %83 = arith.remsi %2, %28 : i32 loc(#loc60) - %84 = arith.divsi %83, %5 : i32 loc(#loc120) - %85 = arith.muli %84, %c8_i32 : i32 loc(#loc121) - %86 = arith.subi %27, %85 : i32 loc(#loc122) - %87 = arith.minsi %86, %c8_i32 : i32 loc(#loc123) - %88 = arith.remsi %83, %87 : i32 loc(#loc124) - %89 = arith.addi %85, %88 : i32 loc(#loc125) - %90 = arith.remsi %83, %5 : i32 loc(#loc126) - %91 = arith.divsi %90, %87 : i32 loc(#loc127) - %92 = tt.addptr %arg27, %89 : !tt.ptr, i32 loc(#loc62) - %93 = tt.load %92 : !tt.ptr loc(#loc63) - %94 = arith.andi %93, %c65535_i32 : i32 loc(#loc64) - %95 = arith.shrsi %93, %c16_i32 : i32 loc(#loc65) - %96 = tt.addptr %arg24, %94 : !tt.ptr, i32 loc(#loc66) - %97 = tt.load %96 : !tt.ptr loc(#loc67) - %98 = tt.addptr %arg25, %94 : !tt.ptr, i32 loc(#loc68) - %99 = tt.load %98 : !tt.ptr loc(#loc69) - %100 = arith.muli %95, %c128_i32 : i32 loc(#loc70) - %101 = arith.muli %91, %c128_i32 : i32 loc(#loc71) - %102 = arith.extsi %99 : i32 to i64 loc(#loc72) - %103 = 
arith.muli %102, %15 : i64 loc(#loc73) - %104 = tt.addptr %arg7, %103 : !tt.ptr, i64 loc(#loc74) - %105 = arith.muli %15, %c2_i64 : i64 loc(#loc54) - tt.experimental_tensormap_create %36, %104, [%c64_i32, %c128_i32], [%arg19, %97], [%105], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc54) - tt.experimental_tensormap_fenceproxy_acquire %36 : !tt.ptr loc(#loc54) - %106 = tt.reinterpret_tensor_descriptor %36 : !tt.ptr to !tt.tensordesc> loc(#loc54) - %107 = arith.extsi %94 : i32 to i64 loc(#loc75) - %108 = arith.extsi %arg11 : i32 to i64 loc(#loc76) - %109 = arith.muli %107, %108 : i64 loc(#loc76) - %110 = tt.addptr %arg10, %109 : !tt.ptr, i64 loc(#loc77) - %111 = arith.muli %19, %c2_i64 : i64 loc(#loc55) - tt.experimental_tensormap_create %37, %110, [%c64_i32, %c128_i32], [%arg19, %arg18], [%111], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc55) - tt.experimental_tensormap_fenceproxy_acquire %37 : !tt.ptr loc(#loc55) - %112 = tt.reinterpret_tensor_descriptor %37 : !tt.ptr to !tt.tensordesc> loc(#loc55) - scf.yield %106, %112, %94, %99, %100, %101, %97 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32 loc(#loc55) - } else { - scf.yield %18, %22, %c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32 loc(#loc1) - } loc(#loc59) - ttng.barrier_expect %41, 32768, %45 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %49 = ttg.memdesc_subview %38[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %50 = ttng.tensor_desc_to_tma_ptr %48#0 : !tt.tensordesc> to !tt.ptr loc(#loc56) - ttng.async_tma_copy_global_to_local %50[%48#4, %c0_i32] %49, %41, %45 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %51 = ttg.memdesc_subview %39[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %52 = ttng.tensor_desc_to_tma_ptr %48#1 : !tt.tensordesc> to !tt.ptr loc(#loc57) - ttng.async_tma_copy_global_to_local %52[%48#5, %c0_i32] %51, %41, %45 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %53 = arith.cmpi sgt, %34, %c1_i32 : i32 loc(#loc58) - %54 = arith.cmpi ne, %35, %c0_i32 : i32 loc(#loc128) - %55 = arith.extui %54 : i1 to i32 loc(#loc78) - %56 = arith.cmpi eq, %55, %c0_i32 : i32 loc(#loc80) - %57 = arith.andi %53, %56 : i1 loc(#loc58) - %58:10 = scf.if %57 -> (!tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32) { - %83 = arith.addi %46, %c148_i32 : i32 loc(#loc81) - %84 = arith.remsi %83, %28 : i32 loc(#loc60) - %85 = arith.divsi %84, %5 : i32 loc(#loc120) - %86 = arith.muli %85, %c8_i32 : i32 loc(#loc121) - %87 = arith.subi %27, %86 : i32 loc(#loc122) - %88 = arith.minsi %87, %c8_i32 : i32 loc(#loc123) - %89 = arith.remsi %84, %88 : i32 loc(#loc124) - %90 = arith.addi %86, %89 : i32 loc(#loc125) - %91 = arith.remsi %84, %5 : i32 loc(#loc126) - %92 = arith.divsi %91, %88 : i32 loc(#loc127) - %93 = tt.addptr %arg27, %90 : !tt.ptr, i32 loc(#loc62) - %94 = tt.load %93 : !tt.ptr loc(#loc63) - %95 = arith.andi %94, %c65535_i32 : 
i32 loc(#loc64) - %96 = arith.shrsi %94, %c16_i32 : i32 loc(#loc65) - %97 = tt.addptr %arg24, %95 : !tt.ptr, i32 loc(#loc66) - %98 = tt.load %97 : !tt.ptr loc(#loc67) - %99 = tt.addptr %arg25, %95 : !tt.ptr, i32 loc(#loc68) - %100 = tt.load %99 : !tt.ptr loc(#loc69) - %101 = arith.muli %96, %c128_i32 : i32 loc(#loc70) - %102 = arith.muli %92, %c128_i32 : i32 loc(#loc71) - %103 = arith.extsi %100 : i32 to i64 loc(#loc72) - %104 = arith.muli %103, %15 : i64 loc(#loc73) - %105 = tt.addptr %arg7, %104 : !tt.ptr, i64 loc(#loc74) - %106 = arith.muli %47, %c128_i32 : i32 loc(#loc54) - %107 = tt.addptr %36, %106 : !tt.ptr, i32 loc(#loc54) - %108 = arith.muli %15, %c2_i64 : i64 loc(#loc54) - tt.experimental_tensormap_create %107, %105, [%c64_i32, %c128_i32], [%arg19, %98], [%108], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc54) - tt.experimental_tensormap_fenceproxy_acquire %107 : !tt.ptr loc(#loc54) - %109 = tt.reinterpret_tensor_descriptor %107 : !tt.ptr to !tt.tensordesc> loc(#loc54) - %110 = arith.addi %47, %c1_i32 : i32 loc(#loc54) - %111 = arith.cmpi slt, %110, %c4_i32 : i32 loc(#loc54) - %112 = arith.select %111, %110, %c0_i32 : i32 loc(#loc54) - %113 = arith.extsi %95 : i32 to i64 loc(#loc75) - %114 = arith.extsi %arg11 : i32 to i64 loc(#loc76) - %115 = arith.muli %113, %114 : i64 loc(#loc76) - %116 = tt.addptr %arg10, %115 : !tt.ptr, i64 loc(#loc77) - %117 = tt.addptr %37, %106 : !tt.ptr, i32 loc(#loc55) - %118 = arith.muli %19, %c2_i64 : i64 loc(#loc55) - tt.experimental_tensormap_create %117, %116, [%c64_i32, %c128_i32], [%arg19, %arg18], [%118], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc55) - tt.experimental_tensormap_fenceproxy_acquire %117 : !tt.ptr loc(#loc55) - %119 = tt.reinterpret_tensor_descriptor %117 : !tt.ptr to !tt.tensordesc> loc(#loc55) - scf.yield %109, %119, %83, %95, %100, %101, %102, %98, %112, %112 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc55) - } else { - scf.yield %48#0, %48#1, %46, %48#2, %48#3, %48#4, %48#5, %48#6, %47, %47 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc1) - } loc(#loc59) - %59 = arith.muli %55, %c64_i32 : i32 loc(#loc82) - ttng.barrier_expect %42, 32768, %53 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %60 = ttg.memdesc_subview %38[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %61 = ttng.tensor_desc_to_tma_ptr %58#0 : !tt.tensordesc> to !tt.ptr loc(#loc56) - ttng.async_tma_copy_global_to_local %61[%58#5, %59] %60, %42, %53 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %62 = ttg.memdesc_subview %39[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %63 = ttng.tensor_desc_to_tma_ptr %58#1 : !tt.tensordesc> to !tt.ptr loc(#loc57) - ttng.async_tma_copy_global_to_local %63[%58#6, %59] %62, %42, %53 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %64 = arith.cmpi sgt, %34, %c2_i32 : i32 loc(#loc58) - %65 = arith.cmpi eq, %55, %35 : i32 loc(#loc79) - %66 = 
arith.addi %55, %c1_i32 : i32 loc(#loc83) - %67 = arith.select %65, %c0_i32, %66 : i32 loc(#loc78) - %68 = arith.cmpi eq, %67, %c0_i32 : i32 loc(#loc80) - %69 = arith.andi %64, %68 : i1 loc(#loc58) - %70:10 = scf.if %69 -> (!tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32) { - %83 = arith.addi %58#2, %c148_i32 : i32 loc(#loc81) - %84 = arith.remsi %83, %28 : i32 loc(#loc60) - %85 = arith.divsi %84, %5 : i32 loc(#loc120) - %86 = arith.muli %85, %c8_i32 : i32 loc(#loc121) - %87 = arith.subi %27, %86 : i32 loc(#loc122) - %88 = arith.minsi %87, %c8_i32 : i32 loc(#loc123) - %89 = arith.remsi %84, %88 : i32 loc(#loc124) - %90 = arith.addi %86, %89 : i32 loc(#loc125) - %91 = arith.remsi %84, %5 : i32 loc(#loc126) - %92 = arith.divsi %91, %88 : i32 loc(#loc127) - %93 = tt.addptr %arg27, %90 : !tt.ptr, i32 loc(#loc62) - %94 = tt.load %93 : !tt.ptr loc(#loc63) - %95 = arith.andi %94, %c65535_i32 : i32 loc(#loc64) - %96 = arith.shrsi %94, %c16_i32 : i32 loc(#loc65) - %97 = tt.addptr %arg24, %95 : !tt.ptr, i32 loc(#loc66) - %98 = tt.load %97 : !tt.ptr loc(#loc67) - %99 = tt.addptr %arg25, %95 : !tt.ptr, i32 loc(#loc68) - %100 = tt.load %99 : !tt.ptr loc(#loc69) - %101 = arith.muli %96, %c128_i32 : i32 loc(#loc70) - %102 = arith.muli %92, %c128_i32 : i32 loc(#loc71) - %103 = arith.extsi %100 : i32 to i64 loc(#loc72) - %104 = arith.muli %103, %15 : i64 loc(#loc73) - %105 = tt.addptr %arg7, %104 : !tt.ptr, i64 loc(#loc74) - %106 = arith.muli %58#8, %c128_i32 : i32 loc(#loc54) - %107 = tt.addptr %36, %106 : !tt.ptr, i32 loc(#loc54) - %108 = arith.muli %15, %c2_i64 : i64 loc(#loc54) - tt.experimental_tensormap_create %107, %105, [%c64_i32, %c128_i32], [%arg19, %98], [%108], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc54) - tt.experimental_tensormap_fenceproxy_acquire %107 : !tt.ptr loc(#loc54) - %109 = tt.reinterpret_tensor_descriptor %107 : !tt.ptr to !tt.tensordesc> loc(#loc54) - %110 = arith.addi %58#8, %c1_i32 : i32 loc(#loc54) - %111 = arith.cmpi slt, %110, %c4_i32 : i32 loc(#loc54) - %112 = arith.select %111, %110, %c0_i32 : i32 loc(#loc54) - %113 = arith.extsi %95 : i32 to i64 loc(#loc75) - %114 = arith.extsi %arg11 : i32 to i64 loc(#loc76) - %115 = arith.muli %113, %114 : i64 loc(#loc76) - %116 = tt.addptr %arg10, %115 : !tt.ptr, i64 loc(#loc77) - %117 = arith.muli %58#9, %c128_i32 : i32 loc(#loc55) - %118 = tt.addptr %37, %117 : !tt.ptr, i32 loc(#loc55) - %119 = arith.muli %19, %c2_i64 : i64 loc(#loc55) - tt.experimental_tensormap_create %118, %116, [%c64_i32, %c128_i32], [%arg19, %arg18], [%119], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc55) - tt.experimental_tensormap_fenceproxy_acquire %118 : !tt.ptr loc(#loc55) - %120 = tt.reinterpret_tensor_descriptor %118 : !tt.ptr to !tt.tensordesc> loc(#loc55) - %121 = arith.addi %58#9, %c1_i32 : i32 loc(#loc55) - %122 = arith.cmpi slt, %121, %c4_i32 : i32 loc(#loc55) - %123 = arith.select %122, %121, %c0_i32 : i32 loc(#loc55) - scf.yield %109, %120, %83, %95, %100, %101, %102, %98, %112, %123 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc55) - } else { - scf.yield %58#0, %58#1, %58#2, %58#3, %58#4, %58#5, %58#6, %58#7, %58#8, %58#9 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc1) 
- } loc(#loc59) - %71 = arith.muli %67, %c64_i32 : i32 loc(#loc82) - ttng.barrier_expect %43, 32768, %64 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %72 = ttg.memdesc_subview %38[%c2_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %73 = ttng.tensor_desc_to_tma_ptr %70#0 : !tt.tensordesc> to !tt.ptr loc(#loc56) - ttng.async_tma_copy_global_to_local %73[%70#5, %71] %72, %43, %64 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %74 = ttg.memdesc_subview %39[%c2_i32, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %75 = ttng.tensor_desc_to_tma_ptr %70#1 : !tt.tensordesc> to !tt.ptr loc(#loc57) - ttng.async_tma_copy_global_to_local %75[%70#6, %71] %74, %43, %64 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %76 = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc84) - %77 = ttg.memdesc_subview %76[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc84) - ttng.tmem_store %cst_4, %77, %true : tensor<128x128xf32, #blocked2> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc84) - %78 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #shared1, #smem, mutable> loc(#loc58) - %79 = ttg.memdesc_subview %78[%c0_i32] : !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc58) - ttng.init_barrier %79, 1 : <1xi64, #shared1, #smem, mutable> loc(#loc58) - %80 = ttg.memdesc_subview %78[%c1_i32] : !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc58) - ttng.init_barrier %80, 1 : <1xi64, #shared1, #smem, mutable> loc(#loc58) - %81:37 = scf.for %arg30 = %c0_i32 to %34 step %c1_i32 iter_args(%arg31 = %67, %arg32 = %70#0, %arg33 = %70#1, %arg34 = %70#2, %arg35 = %70#3, %arg36 = %70#4, %arg37 = %70#5, %arg38 = %70#6, %arg39 = %70#7, %arg40 = %false, %arg41 = %c2_i32, %arg42 = %c-1_i32, %arg43 = %c0_i32, %arg44 = %70#8, %arg45 = %70#9, %arg46 = %c0_i32, %arg47 = %55, %arg48 = %67, %arg49 = %48#4, %arg50 = %58#5, %arg51 = %70#5, %arg52 = %48#5, %arg53 = %58#6, %arg54 = %70#6, %arg55 = %48#6, %arg56 = %58#7, %arg57 = %70#7, %arg58 = %48#2, %arg59 = %58#3, %arg60 = %70#3, %arg61 = %48#3, %arg62 = %58#4, %arg63 = %70#4, %arg64 = %c0_i32, %arg65 = %c0_i32, %arg66 = %c0_i32, %arg67 = %c0_i32) -> (i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i1, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32) : i32 { - %83 = arith.subi %34, %c3_i32 : i32 loc(#loc58) - %84 = arith.cmpi slt, %arg30, %83 : i32 loc(#loc58) - %85 = arith.cmpi eq, %arg31, %35 : i32 loc(#loc79) - %86 = arith.addi %arg31, %c1_i32 : i32 loc(#loc83) - %87 = arith.select %85, %c0_i32, %86 : i32 loc(#loc78) - %88 = arith.cmpi eq, %87, %c0_i32 : i32 loc(#loc80) - %89 = arith.andi %84, %88 : i1 loc(#loc58) - %90:10 = scf.if %89 -> (!tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32) { - %126 = arith.addi %arg34, %c148_i32 : i32 loc(#loc81) - %127 = arith.remsi %126, %28 : i32 loc(#loc60) - %128 = arith.divsi %127, %5 : i32 
loc(#loc120) - %129 = arith.muli %128, %c8_i32 : i32 loc(#loc121) - %130 = arith.subi %27, %129 : i32 loc(#loc122) - %131 = arith.minsi %130, %c8_i32 : i32 loc(#loc123) - %132 = arith.remsi %127, %131 : i32 loc(#loc124) - %133 = arith.addi %129, %132 : i32 loc(#loc125) - %134 = arith.remsi %127, %5 : i32 loc(#loc126) - %135 = arith.divsi %134, %131 : i32 loc(#loc127) - %136 = tt.addptr %arg27, %133 : !tt.ptr, i32 loc(#loc62) - %137 = tt.load %136 : !tt.ptr loc(#loc63) - %138 = arith.andi %137, %c65535_i32 : i32 loc(#loc64) - %139 = arith.shrsi %137, %c16_i32 : i32 loc(#loc65) - %140 = tt.addptr %arg24, %138 : !tt.ptr, i32 loc(#loc66) - %141 = tt.load %140 : !tt.ptr loc(#loc67) - %142 = tt.addptr %arg25, %138 : !tt.ptr, i32 loc(#loc68) - %143 = tt.load %142 : !tt.ptr loc(#loc69) - %144 = arith.muli %139, %c128_i32 : i32 loc(#loc70) - %145 = arith.muli %135, %c128_i32 : i32 loc(#loc71) - %146 = arith.extsi %143 : i32 to i64 loc(#loc72) - %147 = arith.muli %146, %15 : i64 loc(#loc73) - %148 = tt.addptr %arg7, %147 : !tt.ptr, i64 loc(#loc74) - %149 = arith.muli %arg44, %c128_i32 : i32 loc(#loc54) - %150 = tt.addptr %36, %149 : !tt.ptr, i32 loc(#loc54) - %151 = arith.muli %15, %c2_i64 : i64 loc(#loc54) - tt.experimental_tensormap_create %150, %148, [%c64_i32, %c128_i32], [%arg19, %141], [%151], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc54) - tt.experimental_tensormap_fenceproxy_acquire %150 : !tt.ptr loc(#loc54) - %152 = tt.reinterpret_tensor_descriptor %150 : !tt.ptr to !tt.tensordesc> loc(#loc54) - %153 = arith.addi %arg44, %c1_i32 : i32 loc(#loc54) - %154 = arith.cmpi slt, %153, %c4_i32 : i32 loc(#loc54) - %155 = arith.select %154, %153, %c0_i32 : i32 loc(#loc54) - %156 = arith.extsi %138 : i32 to i64 loc(#loc75) - %157 = arith.extsi %arg11 : i32 to i64 loc(#loc76) - %158 = arith.muli %156, %157 : i64 loc(#loc76) - %159 = tt.addptr %arg10, %158 : !tt.ptr, i64 loc(#loc77) - %160 = arith.muli %arg45, %c128_i32 : i32 loc(#loc55) - %161 = tt.addptr %37, %160 : !tt.ptr, i32 loc(#loc55) - %162 = arith.muli %19, %c2_i64 : i64 loc(#loc55) - tt.experimental_tensormap_create %161, %159, [%c64_i32, %c128_i32], [%arg19, %arg18], [%162], [%c1_i32, %c1_i32] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () loc(#loc55) - tt.experimental_tensormap_fenceproxy_acquire %161 : !tt.ptr loc(#loc55) - %163 = tt.reinterpret_tensor_descriptor %161 : !tt.ptr to !tt.tensordesc> loc(#loc55) - %164 = arith.addi %arg45, %c1_i32 : i32 loc(#loc55) - %165 = arith.cmpi slt, %164, %c4_i32 : i32 loc(#loc55) - %166 = arith.select %165, %164, %c0_i32 : i32 loc(#loc55) - scf.yield %152, %163, %126, %138, %143, %144, %145, %141, %155, %166 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc55) - } else { - scf.yield %arg32, %arg33, %arg34, %arg35, %arg36, %arg37, %arg38, %arg39, %arg44, %arg45 : !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc1) - } loc(#loc59) - %91 = arith.addi %arg42, %c1_i32 : i32 loc(#loc58) - %92 = arith.cmpi slt, %91, %c4_i32 : i32 loc(#loc58) - %93 = arith.select %92, %91, %c0_i32 : i32 loc(#loc58) - %94 = arith.xori %arg43, %c1_i32 : i32 loc(#loc58) - %95 = arith.select %92, %arg43, %94 : i32 loc(#loc58) - %96 = ttg.memdesc_subview %40[%93] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> 
!ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.wait_barrier %96, %95 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %97 = ttg.memdesc_subview %39[%93, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %98 = ttg.memdesc_trans %97 {order = array} : !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> -> !ttg.memdesc<64x128xbf16, #shared2, #smem, mutable> loc(#loc57) - %99 = ttg.memdesc_subview %38[%93, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %100 = ttg.memdesc_subview %76[%arg66, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc84) - %101 = ttg.memdesc_subview %78[%arg65] : !ttg.memdesc<2xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> loc(#loc58) - ttng.tc_gen5_mma %99, %98, %100, %arg40, %true, %101 : (!ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64>, !ttg.memdesc<64x128xbf16, #shared2, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1, !ttg.memdesc<1xi64, #shared1, #smem, mutable>) -> () loc(#loc84) - ttng.wait_barrier %101, %arg64 : <1xi64, #shared1, #smem, mutable> loc(#loc84) - %102 = arith.addi %arg65, %c1_i32 : i32 loc(#loc84) - %103 = arith.cmpi eq, %102, %c2_i32 : i32 loc(#loc84) - %104 = arith.select %103, %c0_i32, %102 : i32 loc(#loc84) - %105 = arith.xori %arg64, %c1_i32 : i32 loc(#loc84) - %106 = arith.select %103, %105, %arg64 : i32 loc(#loc84) - %107 = arith.addi %arg41, %c1_i32 : i32 loc(#loc58) - %108 = arith.cmpi slt, %107, %c4_i32 : i32 loc(#loc58) - %109 = arith.select %108, %107, %c0_i32 : i32 loc(#loc58) - %110 = arith.muli %87, %c64_i32 : i32 loc(#loc82) - %111 = ttg.memdesc_subview %40[%109] : !ttg.memdesc<4xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - ttng.barrier_expect %111, 32768, %84 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58) - %112 = ttg.memdesc_subview %38[%109, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %113 = ttng.tensor_desc_to_tma_ptr %90#0 : !tt.tensordesc> to !tt.ptr loc(#loc56) - ttng.async_tma_copy_global_to_local %113[%90#5, %110] %112, %111, %84 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc56) - %114 = ttg.memdesc_subview %39[%109, %c0_i32, %c0_i32] : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %115 = ttng.tensor_desc_to_tma_ptr %90#1 : !tt.tensordesc> to !tt.ptr loc(#loc57) - ttng.async_tma_copy_global_to_local %115[%90#6, %110] %114, %111, %84 : , <1xi64, #shared1, #smem, mutable, 4> -> <128x64xbf16, #shared, #smem, mutable, 4x128x64> loc(#loc57) - %116 = arith.cmpi eq, %arg46, %35 : i32 loc(#loc85) - %117 = arith.cmpi ne, %arg46, %35 : i32 loc(#loc129) - %118 = arith.addi %arg66, %c1_i32 : i32 loc(#loc84) - %119 = arith.cmpi eq, %118, %c2_i32 : i32 loc(#loc84) - %120 = arith.select %119, %c0_i32, %118 : i32 loc(#loc84) - %121 = arith.select %116, %120, %arg66 : i32 loc(#loc84) - %122 = arith.addi %arg67, %c1_i32 : i32 loc(#loc84) - %123 = arith.cmpi eq, %122, %c2_i32 : i32 loc(#loc84) - %124 = arith.select %123, 
%c0_i32, %122 : i32 loc(#loc84)
-      %125 = arith.select %116, %124, %arg67 : i32 loc(#loc84)
-      scf.if %116 {
-        %126 = tt.splat %arg49 : i32 -> tensor<128xi32, #blocked1> loc(#loc87)
-        %127 = arith.addi %126, %8 : tensor<128xi32, #blocked1> loc(#loc87)
-        %128 = tt.splat %arg52 : i32 -> tensor<128xi32, #blocked1> loc(#loc88)
-        %129 = arith.addi %128, %8 : tensor<128xi32, #blocked1> loc(#loc88)
-        %130 = tt.splat %arg55 : i32 -> tensor<128xi32, #blocked1> loc(#loc89)
-        %131 = arith.cmpi slt, %127, %130 : tensor<128xi32, #blocked1> loc(#loc89)
-        %132 = tt.splat %arg18 : i32 -> tensor<128xi32, #blocked1> loc(#loc90)
-        %133 = arith.cmpi slt, %129, %132 : tensor<128xi32, #blocked1> loc(#loc90)
-        %134 = arith.muli %arg58, %arg17 : i32 loc(#loc91)
-        %135 = tt.addptr %arg16, %134 : !tt.ptr, i32 loc(#loc92)
-        %136 = tt.splat %135 : !tt.ptr -> tensor<128x!tt.ptr, #blocked1> loc(#loc93)
-        %137 = tt.addptr %136, %129 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> loc(#loc93)
-        %138 = tt.load %137, %133, %cst_3 : tensor<128x!tt.ptr, #blocked1> loc(#loc94)
-        %139 = ttg.convert_layout %138 : tensor<128xf32, #blocked1> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> loc(#loc95)
-        %140 = tt.expand_dims %139 {axis = 0 : i32} : tensor<128xf32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xf32, #blocked2> loc(#loc96)
-        %141 = tt.broadcast %140 : tensor<1x128xf32, #blocked2> -> tensor<128x128xf32, #blocked2> loc(#loc95)
-        %142 = ttg.memdesc_subview %76[%arg67, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> loc(#loc84)
-        %143 = ttng.tmem_load %142 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked2> loc(#loc97)
-        %144 = arith.addf %143, %141 : tensor<128x128xf32, #blocked2> loc(#loc97)
-        %145 = tt.splat %arg61 : i32 -> tensor<128xi32, #blocked1> loc(#loc98)
-        %146 = arith.addi %145, %127 : tensor<128xi32, #blocked1> loc(#loc98)
-        %147 = tt.splat %arg23 : i32 -> tensor<128xi32, #blocked1> loc(#loc99)
-        %148 = arith.cmpi slt, %146, %147 : tensor<128xi32, #blocked1> loc(#loc99)
-        %149 = tt.addptr %arg22, %arg61 : !tt.ptr, i32 loc(#loc100)
-        %150 = tt.splat %149 : !tt.ptr -> tensor<128x!tt.ptr, #blocked1> loc(#loc101)
-        %151 = tt.addptr %150, %127 : tensor<128x!tt.ptr, #blocked1>, tensor<128xi32, #blocked1> loc(#loc101)
-        %152 = tt.load %151, %148, %cst_1 : tensor<128x!tt.ptr, #blocked1> loc(#loc102)
-        %153 = arith.cmpi ne, %152, %cst_1 : tensor<128xi32, #blocked1> loc(#loc103)
-        %154 = arith.andi %131, %153 : tensor<128xi1, #blocked1> loc(#loc104)
-        %155 = arith.select %154, %152, %cst_2 : tensor<128xi1, #blocked1>, tensor<128xi32, #blocked1> loc(#loc105)
-        %156 = arith.truncf %144 : tensor<128x128xf32, #blocked2> to tensor<128x128xbf16, #blocked2> loc(#loc106)
-        %157 = ttg.convert_layout %155 : tensor<128xi32, #blocked1> -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>> loc(#loc107)
-        %158 = ttg.local_alloc %156 : (tensor<128x128xbf16, #blocked2>) -> !ttg.memdesc<128x128xbf16, #shared, #smem, mutable> loc(#loc107)
-        ttng.fence_async_shared {bCluster = false} loc(#loc107)
-        ttng.async_tma_scatter %23[%157, %arg52] %158 : !tt.ptr, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked3}>>, i32, !ttg.memdesc<128x128xbf16, #shared, #smem, mutable> loc(#loc107)
-        ttng.async_tma_store_wait {pendings = 0 : i32} loc(#loc107)
-      } loc(#loc86)
-      scf.yield %87, %90#0, %90#1, %90#2, %90#3, %90#4, %90#5, %90#6, %90#7, %117, %109, %93, %95, %90#8, %90#9, %arg47, %arg48, %87, %arg50, %arg51, %90#5, %arg53, %arg54, %90#6, %arg56, %arg57, %90#7, %arg59, %arg60, %90#3, %arg62, %arg63, %90#4, %106, %104, %121, %125 : i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i1, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32 loc(#loc58)
-    } {triton.pipeline} loc(#loc58)
-    ttng.inval_barrier %79 : <1xi64, #shared1, #smem, mutable> loc(#loc84)
-    ttng.inval_barrier %80 : <1xi64, #shared1, #smem, mutable> loc(#loc84)
-    ttg.local_dealloc %78 : !ttg.memdesc<2xi64, #shared1, #smem, mutable> loc(#loc84)
-    %82 = ttg.async_wait {num = 0 : i32} loc(#loc58)
-    ttng.inval_barrier %41 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58)
-    ttng.inval_barrier %42 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58)
-    ttng.inval_barrier %43 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58)
-    ttng.inval_barrier %44 : <1xi64, #shared1, #smem, mutable, 4> loc(#loc58)
-    ttg.local_dealloc %38 : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> loc(#loc58)
-    ttg.local_dealloc %39 : !ttg.memdesc<4x128x64xbf16, #shared, #smem, mutable> loc(#loc58)
-    tt.return loc(#loc108)
-  } loc(#loc)
-} loc(#loc)
-#loc1 = loc(unknown)
-#loc2 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":113:37)
-#loc3 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":113:29)
-#loc4 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":126:52)
-#loc5 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":126:57)
-#loc6 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":126:82)
-#loc7 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":78:22)
-#loc8 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":127:61)
-#loc9 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":129:52)
-#loc10 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":131:69)
-#loc11 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":131:47)
-#loc12 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:55)
-#loc13 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:24)
-#loc14 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":133:30)
-#loc15 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":126:90)
-#loc16 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":79:22)
-#loc17 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":80:41)
-#loc18 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":80:30)
-#loc19 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":80:50)
-#loc20 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":81:40)
-#loc21 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":81:34)
-#loc22 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":82:19)
-#loc23 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_common.py":82:30)
-#loc24 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":129:31)
-#loc25 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":129:39)
-#loc26 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":130:31)
-#loc27 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":130:39)
-#loc28 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":131:30)
-#loc29 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:34)
-#loc30 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:44)
-#loc31 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:75)
-#loc32 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":132:68)
-#loc33 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":134:31)
-#loc34 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":134:46)
-#loc35 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":134:35)
-#loc36 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":134:53)
-#loc37 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":135:28)
-#loc38 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":149:12)
-#loc39 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":168:31)
-#loc40 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":185:12)
-#loc41 = loc("/home/jeffniu/code/pytorch/.venv/lib/python3.12/site-packages/triton/language/standard.py":40:22)
-#loc42 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":200:25)
-#loc43 = loc("/home/jeffniu/code/pytorch/.venv/lib/python3.12/site-packages/triton/language/standard.py":40:28)
-#loc44 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":201:39)
-#loc45 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":201:52)
-#loc46 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":202:33)
-#loc47 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":203:38)
-#loc48 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":203:26)
-#loc49 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":203:7)
-#loc50 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":204:25)
-#loc51 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":206:33)
-#loc52 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":213:29)
-#loc53 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":214:38)
-#loc54 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":243:20)
-#loc55 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":272:39)
-#loc56 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":308:28)
-#loc57 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":320:38)
-#loc58 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":213:19)
-#loc59 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":215:11)
-#loc60 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":222:32) -#loc61 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":223:75) -#loc62 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":231:47) -#loc63 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":231:36) -#loc64 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":232:38) -#loc65 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":233:40) -#loc66 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":234:40) -#loc67 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":234:29) -#loc68 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":235:45) -#loc69 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":235:34) -#loc70 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":238:30) -#loc71 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":239:30) -#loc72 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":243:71) -#loc73 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":243:83) -#loc74 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":243:60) -#loc75 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":272:54) -#loc76 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":272:66) -#loc77 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":272:43) -#loc78 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":214:44) -#loc79 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":214:28) -#loc80 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":215:17) -#loc81 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":216:23) -#loc82 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":308:41) -#loc83 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":214:49) -#loc84 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":326:31) -#loc85 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":328:17) -#loc86 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":328:11) -#loc87 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":330:29) -#loc88 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":331:31) -#loc89 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":332:63) -#loc90 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":333:32) -#loc91 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":335:38) -#loc92 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":335:28) -#loc93 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":335:51) -#loc94 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":336:31) -#loc95 = 
loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":353:40) -#loc96 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":353:29) -#loc97 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":353:24) -#loc98 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":361:83) -#loc99 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":361:92) -#loc100 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":361:50) -#loc101 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":361:60) -#loc102 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":361:34) -#loc103 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":362:46) -#loc104 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":362:35) -#loc105 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":369:54) -#loc106 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":370:38) -#loc107 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":370:69) -#loc108 = loc("/home/jeffniu/code/openai/lib/ki/ki/matmul_details/_kernels/_p_matmul_ogs.py":213:4) -#loc109 = loc(callsite(#loc7 at #loc8)) -#loc110 = loc(callsite(#loc16 at #loc8)) -#loc111 = loc(callsite(#loc17 at #loc8)) -#loc112 = loc(callsite(#loc18 at #loc8)) -#loc113 = loc(callsite(#loc19 at #loc8)) -#loc114 = loc(callsite(#loc20 at #loc8)) -#loc115 = loc(callsite(#loc21 at #loc8)) -#loc116 = loc(callsite(#loc22 at #loc8)) -#loc117 = loc(callsite(#loc23 at #loc8)) -#loc118 = loc(callsite(#loc41 at #loc42)) -#loc119 = loc(callsite(#loc43 at #loc42)) -#loc120 = loc(callsite(#loc16 at #loc61)) -#loc121 = loc(callsite(#loc17 at #loc61)) -#loc122 = loc(callsite(#loc18 at #loc61)) -#loc123 = loc(callsite(#loc19 at #loc61)) -#loc124 = loc(callsite(#loc20 at #loc61)) -#loc125 = loc(callsite(#loc21 at #loc61)) -#loc126 = loc(callsite(#loc22 at #loc61)) -#loc127 = loc(callsite(#loc23 at #loc61)) -#loc128 = loc(fused[#loc78, #loc79]) -#loc129 = loc(fused[#loc84, #loc85]) diff --git a/test.mlir b/test.mlir deleted file mode 100644 index f87818cc5a33..000000000000 --- a/test.mlir +++ /dev/null @@ -1,44 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr, #blocked>, %arg4: tensor<128x128x!tt.ptr, #blocked>) { - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %cst_0 = arith.constant dense<1> : tensor<128x128xi32, #blocked> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %0:4 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c0_i32, %arg7 = %arg4, %arg8 = %arg4, %arg9 = %arg4) -> (i32, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>) : i32 { - %1 = arith.cmpi eq, %arg6, %c0_i32 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 - %2 = arith.addi %arg6, %c1_i32 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : i32 - %3:2 = scf.if %1 -> (tensor<128x128xf32, #blocked>, i32) { - %11 = tt.addptr %arg8, %cst_0 : 
tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %12 = tt.load %11 : tensor<128x128x!tt.ptr, #blocked> - %13 = tt.load %arg8 : tensor<128x128x!tt.ptr, #blocked> - %14 = arith.addf %12, %13 : tensor<128x128xf32, #blocked> - %15 = arith.addi %arg5, %arg5 : i32 - scf.yield %14, %15 : tensor<128x128xf32, #blocked>, i32 - } else { - %11 = tt.load %arg8 : tensor<128x128x!tt.ptr, #blocked> - %12 = arith.addf %11, %cst : tensor<128x128xf32, #blocked> - %13 = arith.addi %arg5, %arg5 : i32 - scf.yield %12, %13 : tensor<128x128xf32, #blocked>, i32 - } {b, loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.conditional_load} - %4 = tt.addptr %arg7, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %5 = tt.addptr %arg8, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %6 = tt.addptr %arg9, %cst_0 {loop.cluster = 3 : i32, loop.stage = 1 : i32} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %7 = tt.load %arg7 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : tensor<128x128x!tt.ptr, #blocked> - %8 = arith.addf %7, %3#0 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : tensor<128x128xf32, #blocked> - %9 = arith.cmpi eq, %arg6, %arg1 {loop.cluster = 4 : i32, loop.stage = 0 : i32} : i32 - %10 = scf.if %9 -> (tensor<128x128xf32, #blocked>) { - %11 = tt.load %arg9 : tensor<128x128x!tt.ptr, #blocked> - scf.yield %11 : tensor<128x128xf32, #blocked> - } else { - scf.yield %cst : tensor<128x128xf32, #blocked> - } {a, loop.cluster = 4 : i32, loop.stage = 0 : i32, ttg.conditional_load} - scf.if %9 { - %11 = arith.addf %10, %8 : tensor<128x128xf32, #blocked> - tt.store %arg4, %11 : tensor<128x128x!tt.ptr, #blocked> - } {loop.cluster = 5 : i32, loop.stage = 2 : i32} - scf.yield %2, %4, %5, %6 : i32, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked> - } {tt.num_stages = 3 : i32} - tt.return - } -} diff --git a/test1.mlir b/test1.mlir deleted file mode 100644 index d46fedb1728b..000000000000 --- a/test1.mlir +++ /dev/null @@ -1,219 +0,0 @@ -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> -#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> -#smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - tt.func @foo(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: tensor<128x128x!tt.ptr, #blocked>, %arg4: tensor<128x128x!tt.ptr, #blocked>) { - %false = arith.constant false - %true = arith.constant true - %0 = ub.poison : tensor<128x128xf32, #blocked> - %1 = ub.poison : i32 - %2 = ub.poison : !ttg.async.token - %c2_i32 = arith.constant 2 : i32 - %c-1_i32 = arith.constant -1 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %cst_0 = arith.constant dense<1> : tensor<128x128xi32, #blocked> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %3 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - %4 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - %5 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - %6 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - %7 = ttg.local_alloc : () -> !ttg.memdesc<2x128x128xf32, #shared, #smem, 
mutable> - %8 = arith.cmpi slt, %arg0, %arg1 : i32 - %9:3 = scf.if %8 -> (!ttg.async.token, !ttg.async.token, i32) { - %34 = tt.addptr %arg4, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %35 = ttg.memdesc_subview %3[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_copy_global_to_local %34, %35 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %37 = ttg.async_commit_group %36 - %38 = ttg.memdesc_subview %4[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %39 = ttg.async_copy_global_to_local %arg4, %38 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %40 = ttg.async_commit_group %39 - %41 = arith.addi %arg0, %arg0 : i32 - scf.yield %37, %40, %41 : !ttg.async.token, !ttg.async.token, i32 - } else { - scf.yield %2, %2, %1 : !ttg.async.token, !ttg.async.token, i32 - } - %10:2 = scf.if %8 -> (!ttg.async.token, i32) { - scf.yield %2, %1 : !ttg.async.token, i32 - } else { - %34 = ttg.memdesc_subview %5[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %35 = ttg.async_copy_global_to_local %arg4, %34 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_commit_group %35 - %37 = arith.addi %arg0, %arg0 : i32 - scf.yield %36, %37 : !ttg.async.token, i32 - } - %11 = ttg.memdesc_subview %6[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %12 = tt.splat %8 : i1 -> tensor<128x128xi1, #blocked> - %13 = ttg.async_copy_global_to_local %arg4, %11 mask %12 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %14 = ttg.async_commit_group %13 - %15 = arith.cmpi eq, %arg1, %c0_i32 : i32 - %16 = arith.andi %8, %15 : i1 - %17 = scf.if %16 -> (!ttg.async.token) { - %34 = ttg.memdesc_subview %7[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %35 = ttg.async_copy_global_to_local %arg4, %34 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_commit_group %35 - scf.yield %36 : !ttg.async.token - } else { - scf.yield %2 : !ttg.async.token - } - %18 = arith.addi %arg0, %arg2 : i32 - %19 = arith.cmpi slt, %18, %arg1 : i32 - %20 = tt.addptr %arg4, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %21 = tt.addptr %arg4, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %22 = tt.addptr %arg4, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %23:3 = scf.if %false -> (!ttg.async.token, !ttg.async.token, i32) { - %34 = tt.addptr %21, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %35 = ttg.memdesc_subview %3[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_copy_global_to_local %34, %35 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %37 = ttg.async_commit_group %36 - %38 = ttg.memdesc_subview %4[%c1_i32, 
%c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %39 = ttg.async_copy_global_to_local %21, %38 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %40 = ttg.async_commit_group %39 - %41 = arith.addi %18, %18 : i32 - scf.yield %37, %40, %41 : !ttg.async.token, !ttg.async.token, i32 - } else { - scf.yield %2, %2, %1 : !ttg.async.token, !ttg.async.token, i32 - } - %24:2 = scf.if %false -> (!ttg.async.token, i32) { - scf.yield %2, %1 : !ttg.async.token, i32 - } else { - %34 = ttg.memdesc_subview %5[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %35 = ttg.async_copy_global_to_local %21, %34 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_commit_group %35 - %37 = arith.addi %18, %18 : i32 - scf.yield %36, %37 : !ttg.async.token, i32 - } - %25 = ttg.memdesc_subview %6[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %26 = tt.splat %19 : i1 -> tensor<128x128xi1, #blocked> - %27 = ttg.async_copy_global_to_local %20, %25 mask %26 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %28 = ttg.async_commit_group %27 - %29 = arith.cmpi eq, %arg1, %c1_i32 : i32 - %30 = arith.andi %19, %29 : i1 - %31 = scf.if %30 -> (!ttg.async.token) { - %34 = ttg.memdesc_subview %7[%c1_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %35 = ttg.async_copy_global_to_local %22, %34 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %36 = ttg.async_commit_group %35 - scf.yield %36 : !ttg.async.token - } else { - scf.yield %2 : !ttg.async.token - } - %32:24 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %c1_i32, %arg7 = %20, %arg8 = %21, %arg9 = %22, %arg10 = %c1_i32, %arg11 = %c-1_i32, %arg12 = %true, %arg13 = %false, %arg14 = %9#0, %arg15 = %23#0, %arg16 = %9#1, %arg17 = %23#1, %arg18 = %10#0, %arg19 = %24#0, %arg20 = %9#2, %arg21 = %23#2, %arg22 = %10#1, %arg23 = %24#1, %arg24 = %14, %arg25 = %28, %arg26 = %15, %arg27 = %29, %arg28 = %17, %arg29 = %31) -> (i32, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, !ttg.async.token, !ttg.async.token) : i32 { - %34 = arith.muli %arg2, %c2_i32 : i32 - %35 = arith.subi %arg1, %34 : i32 - %36 = arith.cmpi slt, %arg5, %35 : i32 - %37 = arith.addi %arg11, %c1_i32 : i32 - %38 = arith.cmpi slt, %37, %c2_i32 : i32 - %39 = arith.select %38, %37, %c0_i32 : i32 - %40 = scf.if %arg12 -> (tensor<128x128xf32, #blocked>) { - %76 = ttg.async_wait %arg14 {num = 2 : i32} - %77 = ttg.memdesc_subview %3[%39, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.local_load %77 token %76 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> - %79 = ttg.async_wait %arg16 {num = 2 : i32} - %80 = ttg.memdesc_subview %4[%39, %c0_i32, %c0_i32] : 
!ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %81 = ttg.local_load %80 token %79 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> - %82 = arith.addf %78, %81 : tensor<128x128xf32, #blocked> - scf.yield %82 : tensor<128x128xf32, #blocked> - } else { - scf.yield %0 : tensor<128x128xf32, #blocked> - } - %41 = scf.if %arg12 -> (tensor<128x128xf32, #blocked>) { - scf.yield %0 : tensor<128x128xf32, #blocked> - } else { - %76 = ttg.async_wait %arg18 {num = 2 : i32} - %77 = ttg.memdesc_subview %5[%39, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.local_load %77 token %76 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> - %79 = arith.addf %78, %cst : tensor<128x128xf32, #blocked> - scf.yield %79 : tensor<128x128xf32, #blocked> - } - %42:2 = scf.if %arg12 -> (tensor<128x128xf32, #blocked>, i32) { - scf.yield %40, %arg20 : tensor<128x128xf32, #blocked>, i32 - } else { - scf.yield %41, %arg22 : tensor<128x128xf32, #blocked>, i32 - } {b} - %43 = ttg.async_wait %arg24 {num = 1 : i32} - %44 = ttg.memdesc_subview %6[%39, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %45 = ttg.local_load %44 token %43 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> - %46 = arith.addf %45, %42#0 : tensor<128x128xf32, #blocked> - %47 = arith.addi %arg6, %c1_i32 : i32 - %48 = tt.addptr %arg7, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %49 = tt.addptr %arg8, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %50 = tt.addptr %arg9, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %51 = arith.addi %arg10, %c1_i32 : i32 - %52 = arith.cmpi slt, %51, %c2_i32 : i32 - %53 = arith.select %52, %51, %c0_i32 : i32 - %54 = arith.cmpi eq, %47, %c0_i32 : i32 - %55 = arith.muli %arg2, %c2_i32 : i32 - %56 = arith.addi %arg5, %55 : i32 - %57 = arith.muli %arg2, %c2_i32 : i32 - %58 = arith.addi %arg5, %57 : i32 - %59 = arith.andi %36, %54 : i1 - %60:3 = scf.if %59 -> (!ttg.async.token, !ttg.async.token, i32) { - %76 = tt.addptr %49, %cst_0 : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> - %77 = ttg.memdesc_subview %3[%53, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.async_copy_global_to_local %76, %77 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %79 = ttg.async_commit_group %78 - %80 = ttg.memdesc_subview %4[%53, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %81 = ttg.async_copy_global_to_local %49, %80 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %82 = ttg.async_commit_group %81 - %83 = arith.addi %56, %58 : i32 - scf.yield %79, %82, %83 : !ttg.async.token, !ttg.async.token, i32 - } else { - scf.yield %2, %2, %1 : !ttg.async.token, !ttg.async.token, i32 - } - %61 = arith.muli %arg2, %c2_i32 : i32 - %62 = arith.addi %arg5, %61 : i32 - %63 = arith.muli %arg2, %c2_i32 : i32 - %64 = arith.addi %arg5, %63 : i32 - %65 = arith.andi %36, %54 : i1 - %66:2 
= scf.if %65 -> (!ttg.async.token, i32) { - scf.yield %2, %1 : !ttg.async.token, i32 - } else { - %76 = ttg.memdesc_subview %5[%53, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %77 = ttg.async_copy_global_to_local %49, %76 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.async_commit_group %77 - %79 = arith.addi %62, %64 : i32 - scf.yield %78, %79 : !ttg.async.token, i32 - } - %67 = ttg.memdesc_subview %6[%53, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %68 = tt.splat %36 : i1 -> tensor<128x128xi1, #blocked> - %69 = ttg.async_copy_global_to_local %48, %67 mask %68 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %70 = ttg.async_commit_group %69 - %71 = arith.cmpi eq, %47, %arg1 : i32 - %72 = arith.andi %36, %71 : i1 - %73 = scf.if %72 -> (!ttg.async.token) { - %76 = ttg.memdesc_subview %7[%53, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %77 = ttg.async_copy_global_to_local %50, %76 : tensor<128x128x!tt.ptr, #blocked> -> <128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.async_commit_group %77 - scf.yield %78 : !ttg.async.token - } else { - scf.yield %2 : !ttg.async.token - } - %74 = scf.if %arg26 -> (tensor<128x128xf32, #blocked>) { - %76 = ttg.async_wait %arg28 {num = 2 : i32} - %77 = ttg.memdesc_subview %7[%39, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> - %78 = ttg.local_load %77 token %76 : !ttg.memdesc<128x128xf32, #shared, #smem, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> - scf.yield %78 : tensor<128x128xf32, #blocked> - } else { - scf.yield %0 : tensor<128x128xf32, #blocked> - } - %75 = scf.if %arg26 -> (tensor<128x128xf32, #blocked>) { - scf.yield %74 : tensor<128x128xf32, #blocked> - } else { - scf.yield %cst : tensor<128x128xf32, #blocked> - } {a} - scf.if %arg26 { - %76 = arith.addf %75, %46 : tensor<128x128xf32, #blocked> - tt.store %arg4, %76 : tensor<128x128x!tt.ptr, #blocked> - } - scf.yield %47, %48, %49, %50, %53, %39, %arg13, %54, %arg15, %60#0, %arg17, %60#1, %arg19, %66#0, %arg21, %60#2, %arg23, %66#1, %arg25, %70, %arg27, %71, %arg29, %73 : i32, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, tensor<128x128x!tt.ptr, #blocked>, i32, i32, i1, i1, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, !ttg.async.token, i32, i32, i32, i32, !ttg.async.token, !ttg.async.token, i1, i1, !ttg.async.token, !ttg.async.token - } {tt.num_stages = 3 : i32} - %33 = ttg.async_wait {num = 0 : i32} - ttg.local_dealloc %3 : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - ttg.local_dealloc %4 : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - ttg.local_dealloc %5 : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - ttg.local_dealloc %6 : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - ttg.local_dealloc %7 : !ttg.memdesc<2x128x128xf32, #shared, #smem, mutable> - tt.return - } -} -