BarrierOp in AssertOp lowering can cause deadlock #5632

Open
kishore-ganesh opened this issue Jan 17, 2025 · 5 comments

@kishore-ganesh
Contributor

kishore-ganesh commented Jan 17, 2025

Describe the bug

Currently, the lowering of AssertOp emits a BarrierOp at the end.

However, this is incorrect when the AssertOp sits in a basic block executed by only some threads, since a BarrierOp/__syncthreads must not be placed in a divergent block.
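
To illustrate the hazard, here is a minimal CUDA sketch (illustrative only, not code Triton generates): __syncthreads() is block-wide, so a barrier inside a branch taken by only some threads of the block hangs the threads that do reach it.

    // Minimal CUDA sketch of a divergent barrier (illustrative only).
    __global__ void divergent_barrier() {
        if (threadIdx.x < 16) {
            // Threads 16..blockDim.x-1 never reach this barrier, so the
            // first 16 threads wait forever: a classic barrier deadlock.
            __syncthreads();
        }
    }

    // Launched as divergent_barrier<<<1, 32>>>(), the kernel never returns.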

As an example, I have attached a TTIR file with an AssertOp inside the region of a ReduceOp:

    %25 = "tt.reduce"(%24) <{axis = 1 : i32}> ({
    ^bb0(%arg4: i32 loc("specimen_2.ttir":47:10), %arg5: i32 loc("specimen_2.ttir":47:22)):
      %33 = arith.cmpi sgt, %arg4, %arg5 : i32 loc(#loc52)
      %34 = arith.extsi %arg4 : i32 to i64 loc(#loc53)
      %35 = arith.cmpi sle, %34, %c2147483647_i64 : i64 loc(#loc54)
      %36 = arith.cmpi sge, %34, %c-2147483648_i64 : i64 loc(#loc55)
      %37 = arith.andi %35, %36 : i1 loc(#loc56)
      %38 = tt.splat %37 : i1 -> tensor<1xi1, #blocked> loc(#loc57)
      tt.assert %38, "int32 overflow detected for operation add" : tensor<1xi1, #blocked> loc(#loc58)
      %39 = arith.select %33, %arg4, %arg5 : i32 loc(#loc59)
      tt.reduce.return %39 : i32 loc(#loc60)
    })

The ReduceOp lowering predicates this basic block on a thread-varying condition:

    Value threadIsNeeded = icmp_slt(threadId, i32_val(elems));

This can cause a deadlock, since not all threads may be present at the site of the BarrierOp.
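
Putting the two together, the lowered code is shaped roughly like the following schematic CUDA sketch (not the actual codegen; elems stands in for the number of live reduction lanes, and the combine/assert bodies are elided):

    // Schematic CUDA sketch of the problematic composition (not actual codegen).
    __global__ void combine_region(int elems) {
        bool threadIsNeeded = threadIdx.x < elems;  // thread-varying predicate
        if (threadIsNeeded) {
            // ... combine computation and the tt.assert condition check ...
            __syncthreads();  // barrier from the AssertOp lowering: threads with
                              // threadIdx.x >= elems never arrive -> deadlock
        }
    }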

specimen_2.ttir.txt

Environment details

Triton: pytorch-triton package
Version: 3.2.0+git0d4682f0
GPU: A100

@peterbell10
Contributor

Is this manually constructed IR? I don't think we support tensors in reduce regions.

@kishore-ganesh
Contributor Author

It was obtained through Inductor. I have attached the extracted kernel to this comment.

Running with MLIR_ENABLE_DUMP=1 TRITON_DEBUG=1 python3 specimen.py produces similar IR.

specimen.py.txt

@peterbell10
Contributor

Okay, that's strange, because the code that added a splat before calling tt.assert doesn't exist in that version of Triton. Is there any chance you have multiple versions installed? Perhaps try a clean reinstall.

@kishore-ganesh
Contributor Author

Are you referring to 92a4fad? That's the change where the splat before assert was removed.

However, in this case, the splat before the AssertOp is the result of the ReorderBroadcast pass.

Before:


    %23 = "tt.reduce"(%22) <{axis = 1 : i32}> ({
    ^bb0(%arg4: i32 loc(callsite(#loc1 at #loc12)), %arg5: i32 loc(callsite(#loc1 at #loc12))):
      %32 = arith.cmpi slt, %arg4, %arg5 : i32 loc(#loc40)
      %33 = tt.splat %arg4 : i32 -> tensor<1xi32> loc(#loc48)
      %34 = arith.extsi %33 : tensor<1xi32> to tensor<1xi64> loc(#loc48)
      %35 = arith.cmpi sle, %34, %cst_2 : tensor<1xi64> loc(#loc48)
      %36 = arith.cmpi sge, %34, %cst_3 : tensor<1xi64> loc(#loc48)
      %37 = arith.andi %35, %36 : tensor<1xi1> loc(#loc48)
      tt.assert %37, "int32 overflow detected for operation add" : tensor<1xi1> loc(#loc48)
      %38 = arith.select %32, %arg4, %arg5 : i32 loc(#loc42)
      tt.reduce.return %38 : i32 loc(#loc31)
    }) : (tensor<1x128xi32>) -> tensor<1xi32> loc(#loc31)

After:


    %23 = "tt.reduce"(%22) <{axis = 1 : i32}> ({
    ^bb0(%arg4: i32 loc(callsite(#loc1 at #loc12)), %arg5: i32 loc(callsite(#loc1 at #loc12))):
      %32 = arith.cmpi slt, %arg4, %arg5 : i32 loc(#loc40)
      %33 = arith.extsi %arg4 : i32 to i64 loc(#loc48)
      %34 = arith.cmpi sle, %33, %c2147483647_i64 : i64 loc(#loc48)
      %35 = arith.cmpi sge, %33, %c-2147483648_i64 : i64 loc(#loc48)
      %36 = arith.andi %34, %35 : i1 loc(#loc48)
      %37 = tt.splat %36 : i1 -> tensor<1xi1> loc(#loc48)
      tt.assert %37, "int32 overflow detected for operation add" : tensor<1xi1> loc(#loc48)
      %38 = arith.select %32, %arg4, %arg5 : i32 loc(#loc42)
      tt.reduce.return %38 : i32 loc(#loc31)
    }) : (tensor<1x128xi32>) -> tensor<1xi32> loc(#loc31)

I think it's due to:

    // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...))
    struct MoveSplatAfterElementwisePattern
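
The rewrite is sound for purely elementwise ops: applying the op lane-by-lane to splatted scalars yields the same tensor as splatting the scalar result. A toy host-side check of the identity (illustrative only, plain C++):

    // Toy check of the identity behind MoveSplatAfterElementwisePattern:
    //   elementwise(splat(a), splat(b)) == splat(elementwise(a, b))
    #include <cassert>

    int main() {
      const int N = 4;
      long long a = 5, b = 7;
      long long lhs[N], rhs[N];
      for (int i = 0; i < N; ++i) lhs[i] = a + b; // splat first, then add lane-wise
      long long r = a + b;                        // add once on scalars
      for (int i = 0; i < N; ++i) rhs[i] = r;     // then splat the result
      for (int i = 0; i < N; ++i) assert(lhs[i] == rhs[i]);
      return 0;
    }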

@peterbell10
Contributor

Oh I see, the tensor is coming from here:
https://github.com/pytorch/pytorch/blob/46fbd63405c7dae2efc9dbfb8fa44e85313f0051/torch/_inductor/runtime/triton_helpers.py#L45

The good news is that the assert can never fail, and it gets optimized out by LLVM. We should probably have a verifier that rejects tensor ops inside reduce or scan regions, but it's not really a priority at the moment.
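
A hypothetical sketch of such a verifier against the MLIR C++ API (not Triton's actual code; the function name and error message are made up for illustration):

    // Hypothetical verifier sketch (not Triton's actual code): reject any op
    // directly nested in a reduce/scan combine region that produces a tensor.
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/Operation.h"
    #include "mlir/IR/Region.h"

    static mlir::LogicalResult verifyNoTensorOps(mlir::Region &combineRegion) {
      for (mlir::Operation &op : combineRegion.getOps())
        for (mlir::Type resultType : op.getResultTypes())
          if (mlir::isa<mlir::RankedTensorType>(resultType))
            return op.emitOpError(
                "tensor operations are not allowed in reduce/scan regions");
      return mlir::success();
    }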

cc @davidberard98 regarding the overflow sanitizer running at all; that should have been fixed by pytorch/pytorch#139502
