[Bug] Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false #17341

Cookiee235 opened this issue Sep 5, 2024 · 3 comments
Labels: needs-triage, type: bug


Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/reduced/complete/", line 252, in <module>
    mod = relax.transform.FuseTIR()(mod)
  File "/software/tvm/python/tvm/ir/", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/software/tvm/python/tvm/_ffi/_ctypes/", line 239, in __call__
  File "/software/tvm/python/tvm/_ffi/", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  23: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  22: tvm::transform::Pass::operator()(tvm::IRModule) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform7FuseTIREvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  16: tvm::relax::FuseTIR(tvm::IRModule)
  15: tvm::relax::TIRFuseMutator::Transform(tvm::IRModule)
  14: tvm::relax::FusedTIRConstructor::GetFusedTIR(tvm::IRModule const&, tvm::GlobalVar const&)
  13: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  11: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::FunctionNode const*)
  10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  4: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  3: tvm::relax::FusedTIRConstructor::VisitBinding_(tvm::relax::VarBindingNode const*)
  2: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  1: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  0: tvm::relax::FusedTIRConstructor::VisitExpr_(tvm::relax::CallNode const*)
  File "/software/tvm/src/relax/transform/", line 527
InternalError: Check failed: (prim_func->body->IsInstance<tir::BlockRealizeNode>()) is false: Only schedulable functions (whose body is the root block) can be fused
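
The check above rejects any PrimFunc whose body is not the root block. A minimal sketch of scanning a module for such functions follows. `find_non_root_block_funcs` is a hypothetical helper, not part of TVM, and the classes to test against are passed in as parameters so the sketch runs without TVM installed. Under TVM, one would presumably pass `tvm.tir.PrimFunc` and `tvm.tir.BlockRealize` and build `functions` from the module, for example with `{gv.name_hint: f for gv, f in mod.functions.items()}` (an assumption about how the module is iterated).

```python
def find_non_root_block_funcs(functions, prim_func_type, block_realize_type):
    """Return the sorted names of PrimFuncs whose body is not a root block.

    functions: dict mapping function name -> function object.
    prim_func_type / block_realize_type: classes to test against
    (assumed to be tvm.tir.PrimFunc and tvm.tir.BlockRealize under TVM).
    """
    offenders = []
    for name, func in functions.items():
        if not isinstance(func, prim_func_type):
            continue  # non-PrimFunc entries (e.g. Relax functions) are skipped
        if not isinstance(func.body, block_realize_type):
            offenders.append(name)
    return sorted(offenders)


# Stand-in classes (hypothetical) so the sketch runs without TVM.
class FakeBlockRealize:
    pass


class FakeFor:
    pass


class FakeRelaxFunc:
    pass


class FakePrimFunc:
    def __init__(self, body):
        self.body = body


demo = {
    "ok": FakePrimFunc(FakeBlockRealize()),   # body is the root block
    "bad": FakePrimFunc(FakeFor()),           # body is a bare loop
    "main": FakeRelaxFunc(),                  # not a PrimFunc, ignored
}
offenders = find_non_root_block_funcs(demo, FakePrimFunc, FakeBlockRealize)
print(offenders)
```

Running such a scan before FuseTIR may help narrow down which function triggers the failure.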

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

class Module:
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):

    def concatenate(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(3211254),), "float32"), T_concat: T.Buffer((T.int64(3211264),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211264)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(3211264), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    def concatenate1(tensor_1dim: T.Buffer((T.int64(10),), "float32"), pad_tensor: T.Buffer((T.int64(12534),), "float32"), T_concat: T.Buffer((T.int64(12544),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12544)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(12544), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(10) <= v_ax0, pad_tensor[v_ax0 - T.int64(10)], tensor_1dim[v_ax0])

    def concatenate2(tensor_1dim: T.Buffer((T.int64(6),), "float32"), pad_tensor: T.Buffer((T.int64(2),), "float32"), T_concat: T.Buffer((T.int64(8),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(8)):
            with T.block("T_concat"):
                v_ax0 = T.axis.spatial(T.int64(8), ax0)
                T.reads(pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])
                T_concat[v_ax0] = T.if_then_else(T.int64(6) <= v_ax0, pad_tensor[v_ax0 - T.int64(6)], tensor_1dim[v_ax0])

    def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")):

    def exp_8(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):

    def layer_norm_2(x: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32"), gamma: T.Buffer((T.int64(112), T.int64(112)), "float32"), beta: T.Buffer((T.int64(112), T.int64(112)), "float32"), T_layer_norm: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        x_red_temp_v0 = T.alloc_buffer((T.int64(4), T.int64(64)))
        x_red_temp_v1 = T.alloc_buffer((T.int64(4), T.int64(64)))
        for ax0, ax1, k2, k3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("x_red_temp"):
                v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1])
                with T.init():
                    x_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
                    x_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
                v_x_red_temp_v0: T.float32 = x_red_temp_v0[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3]
                v_x_red_temp_v1: T.float32 = x_red_temp_v1[v_ax0, v_ax1] + x[v_ax0, v_ax1, v_k2, v_k3] * x[v_ax0, v_ax1, v_k2, v_k3]
                x_red_temp_v0[v_ax0, v_ax1] = v_x_red_temp_v0
                x_red_temp_v1[v_ax0, v_ax1] = v_x_red_temp_v1
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_layer_norm"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_red_temp_v0[v_ax0, v_ax1], x_red_temp_v1[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
                T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
                T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0, v_ax1, v_ax2, v_ax3] - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) * T.rsqrt(x_red_temp_v1[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) - x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05) * (x_red_temp_v0[v_ax0, v_ax1] * T.float32(7.9719387755102034e-05)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]

    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):

    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):

    def relu_10(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):

    def reshape_6(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):

    def reshape_62(gv: T.Buffer((T.int64(10),), "float32"), T_reshape: T.Buffer((T.int64(10),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(10)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(10), ax0)
                T.reads(gv[v_ax0 % T.int64(10)])
                T_reshape[v_ax0] = gv[v_ax0 % T.int64(10)]

    def reshape_63(temp: T.Buffer((T.int64(3211264),), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(64), T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(64), T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = temp[(v_ax0 * T.int64(802816) + v_ax1 * T.int64(12544) + v_ax2 * T.int64(112) + v_ax3) % T.int64(3211264)]

    def reshape_64(temp: T.Buffer((T.int64(12544),), "float32"), T_reshape: T.Buffer((T.int64(112), T.int64(112)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(112), T.int64(112)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(112) + v_ax1) % T.int64(12544)]

    def reshape_66(temp: T.Buffer((T.int64(8),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(4)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = temp[(v_ax0 * T.int64(4) + v_ax1) % T.int64(8)]

    def reshape_67(b: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_reshape: T.Buffer((T.int64(6),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(6)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(T.int64(6), ax0)
                T.reads(b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)])
                T_reshape[v_ax0] = b[v_ax0 % T.int64(6) // T.int64(3), v_ax0 % T.int64(3)]

    def tir_matmul(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block(""):
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                T.reads(A[i, k], B[j, k])
                T.writes(C[i, j])
                with T.init():
                    C[i, j] = T.float32(0)
                C[i, j] = C[i, j] + A[i, k] * B[j, k]

    def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")):
        # with T.block("root"):
        for i, j in T.grid(32, 32):
            with T.block(""):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = T.max(A[vi, vj], T.float32(0))

    def zeros(T_full: T.Buffer((T.int64(3211254),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3211254)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(3211254), ax0)
                T_full[v_ax0] = T.float32(0)

    def zeros1(T_full: T.Buffer((T.int64(12534),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(12534)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(12534), ax0)
                T_full[v_ax0] = T.float32(0)

    def zeros2(T_full: T.Buffer((T.int64(2),), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2)):
            with T.block("T_full"):
                v_ax0 = T.axis.spatial(T.int64(2), ax0)
                T_full[v_ax0] = T.float32(0)

    def main_0(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            a = R.call_tir(cls.exp, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            b = R.call_tir(cls.exp, (a,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            tensor_1dim = R.call_tir(cls.reshape_67, (b,), out_sinfo=R.Tensor((6,), dtype="float32"))
            pad_tensor = R.call_tir(cls.zeros2, R.tuple(), out_sinfo=R.Tensor((2,), dtype="float32"))
            temp = R.call_tir(cls.concatenate2, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((8,), dtype="float32"))
            para0 = R.call_tir(cls.reshape_66, (temp,), out_sinfo=R.Tensor((2, 4), dtype="float32"))
            res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0(para0)
        return res

    def main_0_9_0(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        R.func_attr({"relax.force_pure": 1})
        cls = Module
        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global"))
        gv: R.Tensor((10,), dtype="float32") = alloc4
        tensor_1dim = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor = R.call_tir(cls.zeros, R.tuple(), out_sinfo=R.Tensor((3211254,), dtype="float32"))
        temp = R.call_tir(cls.concatenate, (tensor_1dim, pad_tensor), out_sinfo=R.Tensor((3211264,), dtype="float32"))
        para0 = R.call_tir(cls.reshape_63, (temp,), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
        tensor_1dim_1 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_1 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_1 = R.call_tir(cls.concatenate1, (tensor_1dim_1, pad_tensor_1), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para1 = R.call_tir(cls.reshape_64, (temp_1,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        tensor_1dim_2 = R.call_tir(cls.reshape_62, (gv,), out_sinfo=R.Tensor((10,), dtype="float32"))
        pad_tensor_2 = R.call_tir(cls.zeros1, R.tuple(), out_sinfo=R.Tensor((12534,), dtype="float32"))
        temp_2 = R.call_tir(cls.concatenate1, (tensor_1dim_2, pad_tensor_2), out_sinfo=R.Tensor((12544,), dtype="float32"))
        para2 = R.call_tir(cls.reshape_64, (temp_2,), out_sinfo=R.Tensor((112, 112), dtype="float32"))
        res: R.Tensor((4, 64, 112, 112), dtype="float32") = cls.main_0_9_0_2(para0, para1, para2)
        return res

    def main_0_9_0_2(x: R.Tensor((4, 64, 112, 112), dtype="float32"), gamma: R.Tensor((112, 112), dtype="float32"), beta: R.Tensor((112, 112), dtype="float32")) -> R.Tensor((4, 64, 112, 112), dtype="float32"):
        cls = Module
        with R.dataflow():
            ln = R.call_tir(cls.layer_norm_2, (x, gamma, beta), out_sinfo=R.Tensor((4, 64, 112, 112), dtype="float32"))
        return ln

mod = Module
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)

Cookiee235 commented Sep 5, 2024

This bug can only be triggered by the given pass sequence (i.e., [AnnotateTIROpPattern, FuseOps, FuseTIR]). It seems that a Relax variable was incorrectly mapped to two different TIR buffer objects.

@Lunderberg Could you help me review it? Thanks a lot!
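
To confirm which pass in the sequence actually raises, the passes can be applied one at a time. The sketch below is a generic helper, not a TVM API: `first_failing_pass` and the stand-in passes are hypothetical. With TVM, the pairs would presumably look like `("FuseTIR", relax.transform.FuseTIR())`, since pass objects are called directly on the module in the reproducer.

```python
def first_failing_pass(mod, passes):
    """Apply passes in order and report the first one that raises.

    passes: list of (name, callable) pairs; each callable maps a module
    to a module. Returns (index, name, error), or None if all succeed.
    """
    for index, (name, apply_pass) in enumerate(passes):
        try:
            mod = apply_pass(mod)
        except Exception as err:
            return index, name, err
    return None


# Stand-in passes (hypothetical) operating on a plain list.
def annotate(m):
    return m + ["annotated"]


def fuse_ops(m):
    return m + ["fused_ops"]


def fuse_tir(m):
    raise RuntimeError("Check failed")


result = first_failing_pass(
    [],
    [
        ("AnnotateTIROpPattern", annotate),
        ("FuseOps", fuse_ops),
        ("FuseTIR", fuse_tir),
    ],
)
print(result[0], result[1])
```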

Lunderberg commented

Hmm. I can't reproduce your error from the test case. Instead, I run into a different check failure: Check failed: (StructuralEqual()((*it).second, new_buf)) is false: Inconsistent buffers B and b mapped to the same relax var: b
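
The message says that two structurally different buffers were bound to the same Relax variable. One plausible reading of the reproducer, offered as a guess rather than a confirmed diagnosis, is that `exp` declares its buffers with plain integer extents `(2, 3)` while `reshape_67` declares `b` with `T.int64` extents. The toy model below is not TVM code: `ToyBuffer` and `ToyVarToBufferMap` are hypothetical names that only illustrate how such a consistency check can fail when extents differ in index dtype.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBuffer:
    name: str
    shape: tuple  # tuple of (extent, index_dtype) pairs
    dtype: str

    def structure(self):
        # The name is ignored; only shape and element dtype are compared.
        return (self.shape, self.dtype)


class ToyVarToBufferMap:
    def __init__(self):
        self._map = {}

    def bind(self, relax_var, buf):
        old = self._map.get(relax_var)
        if old is not None and old.structure() != buf.structure():
            raise ValueError(
                f"Inconsistent buffers {old.name} and {buf.name} "
                f"mapped to the same relax var: {relax_var}"
            )
        self._map.setdefault(relax_var, buf)


# Same extents and element dtype, but different index dtypes.
B = ToyBuffer("B", ((2, "int32"), (3, "int32")), "float32")
b = ToyBuffer("b", ((2, "int64"), (3, "int64")), "float32")

mapping = ToyVarToBufferMap()
mapping.bind("b", B)
try:
    mapping.bind("b", b)
    message = None
except ValueError as err:
    message = str(err)
print(message)
```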

Cookiee235 commented Sep 6, 2024

@Lunderberg You are right! The crash message in the original report comes from an earlier TVM version (0.17.dev0). After switching to a later version (0.18.dev0), the test still crashes, but with the same message that you reported.
