[Bug] Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false #17341

Open
Cookiee235 opened this issue Sep 5, 2024 · 3 comments
Labels: needs-triage, type: bug


Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/reduced/complete/1954_test.py", line 252, in <module>
    mod = relax.transform.FuseTIR()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  23: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  22: tvm::transform::Pass::operator()(tvm::IRModule) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  16: tvm::relax::FuseTIR(tvm::IRModule)
  15: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
  14: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
  13: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
  10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  4: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  3: tvm::relax::FusedTIRConstructor::VisitBinding_(tvm::relax::VarBindingNode const*)
  2: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  1: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  0: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::CallNode const*)
  File "/software/tvm/src/relax/transform/fuse_tir.cc", line 527
InternalError: Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false: Only schedulable functions (whose body is the root block) can be fused

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def concatenate(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(3211254),), "float32"), T_concat: T.Buffer((T.int64(3211264),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211264)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(3211264), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate1(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(12534),), "float32"), T_concat: T.Buffer((T.int64(12544),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12544)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(12544), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def concatenate2(tensor_1dim: T.Buffer((T.int64(6),), "float32"), pad_tensor: T.Buffer((T.int64(2),), "float32"), T_concat: T.Buffer((T.int64(8),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(8)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(8), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])
                T.writes(T_concat[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(6) <= v_ax0, pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])

    @T.prim_func(private=True)
    def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):
        T.evaluate(0)

    @T.prim_func
    def exp_8(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def layer_norm_2(x: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32"), gamma: T.Buffer((T.int64(112), T.int64(112)), "float32"), beta: T.Buffer((T.int64(112), T.int64(112)), "float32"), T_layer_norm: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        x_red_temp_v0 = T.alloc_buffer((T.int64(4), T.int64(64)))
        x_red_temp_v1 = T.alloc_buffer((T.int64(4), T.int64(64)))
        for ax0, ax1, k2, k3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("x_red_temp"):
                v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1])
                with T.init():
                    x_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                    x_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                v_x_red_temp_v0: T.float32 = x_red_temp_v0[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3]
                v_x_red_temp_v1: T.float32 = x_red_temp_v1[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3] * x[v_ax0, v_ax1, v_k2, v_k3]
                x_red_temp_v0[v_ax0, v_ax1] = v_x_red_temp_v0
                x_red_temp_v1[v_ax0, v_ax1] = v_x_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, v_ax1, v_ax2, v_ax3] - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) * T.rsqrt(x_red_temp_v1[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) * (x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]

    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def relu_10(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def reshape_6(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def reshape_62(gv: T.Buffer((T.int64(10),), "float32"), T_reshape: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(10)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(10), ax0)
                T.reads(gv[v_ax0 % T.int64(10)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(10)]

    @T.prim_func(private=True)
    def reshape_63(temp: T.Buffer((T.int64(3211264),), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)]

    @T.prim_func(private=True)
    def reshape_64(temp: T.Buffer((T.int64(12544),), "float32"), T_reshape: T.Buffer((T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)]

    @T.prim_func(private=True)
    def reshape_66(temp: T.Buffer((T.int64(8),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)]

    @T.prim_func(private=True)
    def reshape_67(b: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_reshape: T.Buffer((T.int64(6),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(6)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(6), ax0)
                T.reads(b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)]

    @T.prim_func
    def tir_matmul(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block(""):
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                T.reads(A[i, k], B[j, k])
                T.writes(C[i, j])
                with T.init():
                    C[i, j] = T.float32(0)
                C[i, j] = C[i, j] + A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i, j in T.grid(32, 32):
            with T.block(""):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = T.max(A[vi, vj], T.float32(0))

    @T.prim_func(private=True)
    def zeros(T_full: T.Buffer((T.int64(3211254),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211254)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(3211254), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros1(T_full: T.Buffer((T.int64(12534),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12534)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(12534), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @T.prim_func(private=True)
    def zeros2(T_full: T.Buffer((T.int64(2),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T.reads()
                T.writes(T_full[v_ax0])
                T_full[v_ax0] = T.float32(0)

    @R.function
    def main_0(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            tensor_1dim = R.call_tir(cls.reshape_67, (b,), out_sinfo=R.Tensor((6,), dtype="float32"))
            pad_tensor = R.call_tir(cls.zeros2, R.tuple(), out_sinfo=R.Tensor((2,), dtype="float32"))
            temp = R.call_tir(cls.concatenate2, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((8,), dtype="float32"))
            para0 = R.call_tir(cls.reshape_66, (temp,), out_sinfo=R.Tensor((2, 4), dtype="float32"))
            res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0(para0)
            R.output(res)
        return res

    @R.function
    def main_0_9_0(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        R.func_attr({"relax.force_pure": 1})
        cls = Module
        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global"))
        gv: R.Tensor((10,), dtype="float32") = alloc4
        tensor_1dim = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((3211254,), dtype="float32"))
        temp = R.call_tir(cls.concatenate, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((3211264,), dtype="float32"))
        para0 = R.call_tir(cls.reshape_63, (temp,), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
        tensor_1dim_1 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_1 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_1 = R.call_tir(cls.concatenate1, (tensor_1dim_1, pad_tensor_1), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para1 = R.call_tir(cls.reshape_64, (temp_1,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        tensor_1dim_2 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_2 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_2 = R.call_tir(cls.concatenate1, (tensor_1dim_2, pad_tensor_2), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para2 = R.call_tir(cls.reshape_64, (temp_2,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0_2(para0, para1, para2)
        return res

    @R.function
    def main_0_9_0_2(x: R.Tensor((4, 64, 112, 112), dtype="float32"), gamma: R.Tensor((112, 112), dtype="float32"), beta: R.Tensor((112, 112), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            ln = R.call_tir(cls.layer_norm_2, (x, gamma, beta), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
            R.output(ln)
        return ln

mod = Module
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)
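
For context (not part of the original reproduction script), here is a minimal sketch that lists which prim_funcs in the module satisfy the requirement named in the error, i.e. whose body is the root block (a BlockRealize node). Presumably the stub functions above whose body is only T.evaluate(0), such as add and exp, are the ones that fail this check:

from tvm import tir

# Iterate over every function in the module above and report whether its body
# is a BlockRealize (the implicit root block), which is what FuseTIR requires.
for gvar, func in Module.functions.items():
    if isinstance(func, tir.PrimFunc):
        schedulable = isinstance(func.body, tir.BlockRealize)
        print(f"{gvar.name_hint}: schedulable={schedulable}")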

Cookiee235 commented Sep 5, 2024

This bug can only be triggered by the given pass sequence (i.e., [AnnotateTIROpPattern, FuseOps, FuseTIR]). It seems that the Relax variable was incorrectly mapped to two different TIR buffer objects.
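
For reference, a hedged sketch (not part of the original comment) of that claim: running FuseTIR directly on the module, without AnnotateTIROpPattern and FuseOps beforehand, should not hit this check, since no fused groups have been formed yet:

from tvm import relax

mod = Module
# Assumption: with no fused (primitive) Relax functions present, FuseTIR has
# nothing to fuse and is expected to leave the module unchanged.
mod = relax.transform.FuseTIR()(mod)
print(mod)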

@Lunderberg Could you help me review it? Thanks a lot!

Lunderberg commented

Hmm. I can't reproduce your error from the test case. Instead, I end up running into Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers B and b mapped to the same relax var: b


Cookiee235 commented Sep 6, 2024

@Lunderberg You are right! The above crash message is thrown when using an earlier TVM version (i.e., 0.17.dev0). When I switched TVM to a later version (i.e., 0.18.dev0), the test crashed with the same message as the one you showed.
