codegen error: !resolution.resolution_points_.empty() INTERNAL ASSERT FAILED - Could not resolve persistent buffer #1123

Closed
jjsjann123 opened this issue Oct 21, 2023 · 5 comments · Fixed by #2946
Labels: bug, Thunder

jjsjann123 (Collaborator) commented Oct 21, 2023

RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/scheduler/utils.cpp":370, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Could not resolve persistent buffer: 0x7f7f2d34da40
Exception raised from getResolutionPointsOf at /opt/pytorch/nvfuser/csrc/scheduler/utils.cpp:370 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f81c187c17f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7f81c1be1343 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::scheduler_utils::persistentBuffers(nvfuser::Fusion*) + 0x1634 (0x7f81c1f814e4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x7f3f50 (0x7f81c1f3ef50 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x7fe187 (0x7f81c1f49187 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x7ff0df (0x7f81c1f4a0df in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x544b67 (0x7f81c1c8fb67 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x544ebb (0x7f81c1c8febb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x545070 (0x7f81c1c90070 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x558775 (0x7f81c1ca3775 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::SegmentCandidateFinder::SegmentCandidateFinder(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const*, nvfuser::SegmentCandidateFinderOptions) + 0x46f (0x7f81c1ca3cef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x4ef402 (0x7f81c1c3a402 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x558efe (0x7f81c1ca3efe in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x6cd374 (0x7f81c1e18374 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x6ce85e (0x7f81c1e1985e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xa9 (0x7f81c1e1ce89 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x394 (0x7f81c201d284 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)

Repro Python script:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id7(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, None, True, True], dtype=DataType.Bool, is_cpu=False)
    T2 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False)
    T3 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False)
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False)
    S5 = fd.define_scalar(float("-inf"), dtype=DataType.Double)
    T6 = fd.ops.where(T1, T0, S5)
    T7 = fd.ops.cast(T3, dtype=DataType.Float)
    T8 = fd.ops.cast(T6, dtype=DataType.Float)
    T9 = fd.ops.max(T8, dims=[3], keepdim=False, dtype=DataType.Null)
    S10 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(128, dtype=DataType.Int)
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    V14 = fd.define_vector([S10, S11, S12, S13], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T9, shape=V14, broadcast_dims=[0, 1, 2])
    S16 = fd.define_scalar(16, dtype=DataType.Int)
    S17 = fd.define_scalar(16, dtype=DataType.Int)
    S18 = fd.define_scalar(128, dtype=DataType.Int)
    S19 = fd.define_scalar(1, dtype=DataType.Int)
    V20 = fd.define_vector([S16, S17, S18, S19], dtype=DataType.Int)
    T21 = fd.ops.broadcast_in_dim(T9, shape=V20, broadcast_dims=[0, 1, 2])
    T22 = fd.ops.cast(T2, dtype=DataType.Float)
    S23 = fd.define_scalar(16, dtype=DataType.Int)
    S24 = fd.define_scalar(16, dtype=DataType.Int)
    S25 = fd.define_scalar(128, dtype=DataType.Int)
    S26 = fd.define_scalar(128, dtype=DataType.Int)
    V27 = fd.define_vector([S23, S24, S25, S26], dtype=DataType.Int)
    T28 = fd.ops.broadcast_in_dim(T15, shape=V27, broadcast_dims=[0, 1, 2, 3])
    T29 = fd.ops.eq(T8, T28)
    T30 = fd.ops.sum(T29, dims=[3], keepdim=False, dtype=DataType.Null)
    S31 = fd.define_scalar(16, dtype=DataType.Int)
    S32 = fd.define_scalar(16, dtype=DataType.Int)
    S33 = fd.define_scalar(128, dtype=DataType.Int)
    S34 = fd.define_scalar(1, dtype=DataType.Int)
    V35 = fd.define_vector([S31, S32, S33, S34], dtype=DataType.Int)
    T36 = fd.ops.broadcast_in_dim(T30, shape=V35, broadcast_dims=[0, 1, 2])
    S37 = fd.define_scalar(16, dtype=DataType.Int)
    S38 = fd.define_scalar(16, dtype=DataType.Int)
    S39 = fd.define_scalar(128, dtype=DataType.Int)
    S40 = fd.define_scalar(128, dtype=DataType.Int)
    V41 = fd.define_vector([S37, S38, S39, S40], dtype=DataType.Int)
    T42 = fd.ops.broadcast_in_dim(T21, shape=V41, broadcast_dims=[0, 1, 2, 3])
    T43 = fd.ops.sub(T8, T42)
    T44 = fd.ops.exp(T43)
    T45 = fd.ops.sum(T44, dims=[3], keepdim=False, dtype=DataType.Null)
    S46 = fd.define_scalar(16, dtype=DataType.Int)
    S47 = fd.define_scalar(16, dtype=DataType.Int)
    S48 = fd.define_scalar(128, dtype=DataType.Int)
    S49 = fd.define_scalar(1, dtype=DataType.Int)
    V50 = fd.define_vector([S46, S47, S48, S49], dtype=DataType.Int)
    T51 = fd.ops.broadcast_in_dim(T45, shape=V50, broadcast_dims=[0, 1, 2])
    S52 = fd.define_scalar(16, dtype=DataType.Int)
    S53 = fd.define_scalar(16, dtype=DataType.Int)
    S54 = fd.define_scalar(128, dtype=DataType.Int)
    S55 = fd.define_scalar(128, dtype=DataType.Int)
    V56 = fd.define_vector([S52, S53, S54, S55], dtype=DataType.Int)
    T57 = fd.ops.broadcast_in_dim(T51, shape=V56, broadcast_dims=[0, 1, 2, 3])
    S58 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T59 = fd.ops.mul(S58, T7)
    T60 = fd.ops.cast(T4, dtype=DataType.Float)
    T61 = fd.ops.reciprocal(T57)
    T62 = fd.ops.mul(T44, T61)
    T63 = fd.ops.reciprocal(T57)
    T64 = fd.ops.mul(T62, T63)
    T65 = fd.ops.mul(T22, T59)
    T66 = fd.ops.mul(T62, T22)
    T67 = fd.ops.reciprocal(T57)
    T68 = fd.ops.mul(T65, T67)
    T69 = fd.ops.neg(T65)
    T70 = fd.ops.mul(T69, T64)
    T71 = fd.ops.sum(T70, dims=[3], keepdim=False, dtype=DataType.Null)
    S72 = fd.define_scalar(16, dtype=DataType.Int)
    S73 = fd.define_scalar(16, dtype=DataType.Int)
    S74 = fd.define_scalar(128, dtype=DataType.Int)
    S75 = fd.define_scalar(1, dtype=DataType.Int)
    V76 = fd.define_vector([S72, S73, S74, S75], dtype=DataType.Int)
    T77 = fd.ops.broadcast_in_dim(T71, shape=V76, broadcast_dims=[0, 1, 2])
    T78 = fd.ops.sum(T77, dims=[3], keepdim=False, dtype=DataType.Null)
    S79 = fd.define_scalar(16, dtype=DataType.Int)
    S80 = fd.define_scalar(16, dtype=DataType.Int)
    S81 = fd.define_scalar(128, dtype=DataType.Int)
    S82 = fd.define_scalar(1, dtype=DataType.Int)
    V83 = fd.define_vector([S79, S80, S81, S82], dtype=DataType.Int)
    T84 = fd.ops.broadcast_in_dim(T78, shape=V83, broadcast_dims=[0, 1, 2])
    S85 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T86 = fd.ops.mul(T66, S85)
    T87 = fd.ops.cast(T86, dtype=DataType.BFloat16)
    S88 = fd.define_scalar(16, dtype=DataType.Int)
    S89 = fd.define_scalar(16, dtype=DataType.Int)
    S90 = fd.define_scalar(128, dtype=DataType.Int)
    S91 = fd.define_scalar(128, dtype=DataType.Int)
    V92 = fd.define_vector([S88, S89, S90, S91], dtype=DataType.Int)
    T93 = fd.ops.broadcast_in_dim(T84, shape=V92, broadcast_dims=[0, 1, 2, 3])
    T94 = fd.ops.add(T93, T68)
    T95 = fd.ops.mul(T94, T44)
    T96 = fd.ops.neg(T95)
    T97 = fd.ops.sum(T96, dims=[3], keepdim=False, dtype=DataType.Null)
    S98 = fd.define_scalar(16, dtype=DataType.Int)
    S99 = fd.define_scalar(16, dtype=DataType.Int)
    S100 = fd.define_scalar(128, dtype=DataType.Int)
    S101 = fd.define_scalar(1, dtype=DataType.Int)
    V102 = fd.define_vector([S98, S99, S100, S101], dtype=DataType.Int)
    T103 = fd.ops.broadcast_in_dim(T97, shape=V102, broadcast_dims=[0, 1, 2])
    T104 = fd.ops.sum(T103, dims=[3], keepdim=False, dtype=DataType.Null)
    S105 = fd.define_scalar(16, dtype=DataType.Int)
    S106 = fd.define_scalar(16, dtype=DataType.Int)
    S107 = fd.define_scalar(128, dtype=DataType.Int)
    S108 = fd.define_scalar(1, dtype=DataType.Int)
    V109 = fd.define_vector([S105, S106, S107, S108], dtype=DataType.Int)
    T110 = fd.ops.broadcast_in_dim(T104, shape=V109, broadcast_dims=[0, 1, 2])
    S111 = fd.define_scalar(16, dtype=DataType.Int)
    S112 = fd.define_scalar(16, dtype=DataType.Int)
    S113 = fd.define_scalar(128, dtype=DataType.Int)
    S114 = fd.define_scalar(128, dtype=DataType.Int)
    V115 = fd.define_vector([S111, S112, S113, S114], dtype=DataType.Int)
    T116 = fd.ops.broadcast_in_dim(T110, shape=V115, broadcast_dims=[0, 1, 2, 3])
    T117 = fd.ops.cast(T29, dtype=DataType.Float)
    T118 = fd.ops.mul(T116, T117)
    S119 = fd.define_scalar(16, dtype=DataType.Int)
    S120 = fd.define_scalar(16, dtype=DataType.Int)
    S121 = fd.define_scalar(128, dtype=DataType.Int)
    S122 = fd.define_scalar(128, dtype=DataType.Int)
    V123 = fd.define_vector([S119, S120, S121, S122], dtype=DataType.Int)
    T124 = fd.ops.broadcast_in_dim(T36, shape=V123, broadcast_dims=[0, 1, 2, 3])
    T125 = fd.ops.cast(T124, dtype=DataType.Float)
    T126 = fd.ops.reciprocal(T125)
    T127 = fd.ops.mul(T118, T126)
    T128 = fd.ops.add(T127, T95)
    T129 = fd.ops.cast(T128, dtype=DataType.BFloat16)
    S130 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T131 = fd.ops.where(T1, T129, S130)
    fd.add_output(T131)
    fd.add_output(T87)
    fd.add_output(T60)

with FusionDefinition() as fd:
    nvfuser_fusion_id7(fd)

inputs = [
    torch.randn((4194304,), dtype=torch.bfloat16, device='cuda').as_strided((16, 16, 128, 128), (262144, 16384, 128, 1)),
    torch.randint(0, 1, (16384,), dtype=torch.bool, device='cuda').as_strided((16, 16, 128, 128), (0, 0, 128, 1)),
    torch.randint(0, 1, (4194304,), dtype=torch.bool, device='cuda').as_strided((16, 16, 128, 128), (262144, 16384, 128, 1)),
    torch.randn((4194304,), dtype=torch.bfloat16, device='cuda').as_strided((16, 16, 128, 128), (262144, 16384, 128, 1)),
    torch.randn((8388608,), dtype=torch.bfloat16, device='cuda').as_strided((16, 128, 4096), (524288, 4096, 1)),
]
fd.execute(inputs)
jjsjann123 added the bug label Oct 21, 2023
IvanYashchuk (Collaborator) commented Sep 12, 2024

This bug is still reproducible, and we're now hitting it with Riccardo's work on saved-tensor recomputation to enable longer sequences and larger-parameter-count models. @naoyam, who would be the right person to take a look and fix the problem?

The fusion definition is a bit shorter and the error message is the same. Here's a script to reproduce the error:

# CUDA devices:
#  0: NVIDIA H100 80GB HBM3
#  1: NVIDIA H100 80GB HBM3
#  2: NVIDIA H100 80GB HBM3
#  3: NVIDIA H100 80GB HBM3
#  4: NVIDIA H100 80GB HBM3
#  5: NVIDIA H100 80GB HBM3
#  6: NVIDIA H100 80GB HBM3
#  7: NVIDIA H100 80GB HBM3
# torch version: 2.5.0a0+gitc0436c5
# cuda version: 12.6
# nvfuser version: 0.2.10+gite3e8485
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 8192, 4096], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 8192, 4096], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[8192, 4096], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[4096], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.ops.cast(T0, dtype=DataType.Float)
    T5 = fd.ops.cast(T1, dtype=DataType.Float)
    T6 = fd.ops.add(T5, T4)
    S7 = fd.define_scalar(1, dtype=DataType.Int)
    S8 = fd.define_scalar(8192, dtype=DataType.Int)
    S9 = fd.define_scalar(4096, dtype=DataType.Int)
    T11 = fd.ops.reshape(T2, new_shape=[S7, S8, S9])
    T12 = fd.ops.cast(T3, dtype=DataType.Float)
    T13 = fd.ops.sum(T6, dims=[2], keepdim=False, dtype=DataType.Null)
    T14 = fd.ops.cast(T11, dtype=DataType.Float)
    T19 = fd.ops.broadcast_in_dim(T12, shape=[1, 8192, 4096], broadcast_dims=[2])
    T24 = fd.ops.broadcast_in_dim(T13, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    T25 = fd.ops.mul(T19, T14)
    T26 = fd.ops.mul(T6, T25)
    T27 = fd.ops.sum(T26, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T32 = fd.ops.broadcast_in_dim(T27, shape=[1, 8192, 1], broadcast_dims=[1])
    T33 = fd.ops.sum(T32, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T37 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192], broadcast_dims=[1])
    T42 = fd.ops.broadcast_in_dim(T37, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    T47 = fd.ops.broadcast_in_dim(T42, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])
    T52 = fd.ops.broadcast_in_dim(T24, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])
    T53 = fd.ops.mul(T6, T47)
    T54 = fd.ops.mul(T52, T25)
    T55 = fd.ops.add(T54, T53)
    T56 = fd.ops.cast(T55, dtype=DataType.BFloat16)
    S57 = fd.define_scalar(8192, dtype=DataType.Int)
    S58 = fd.define_scalar(4096, dtype=DataType.Int)
    T60 = fd.ops.reshape(T56, new_shape=[S57, S58])
    fd.add_output(T55)
    fd.add_output(T60)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 8192, 4096), (33554432, 4096, 1)),
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 8192, 4096), (33554432, 4096, 1)),
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((8192, 4096), (4096, 1)),
    torch.randn(4096, dtype=torch.bfloat16, device='cuda:0').as_strided((4096,), (1,)),
]
fd.execute(inputs)

jacobhinkle (Collaborator) commented

It seems like this part is problematic:

    T37 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192], broadcast_dims=[1])
    T42 = fd.ops.broadcast_in_dim(T37, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    T47 = fd.ops.broadcast_in_dim(T42, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])

Changing this to the following avoids the error:

    T47 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192, 4096], broadcast_dims=[1])

So it seems the persistent buffer resolution traversal doesn't handle repeated broadcasts properly.
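For what it's worth, the rewrite above is purely structural: the chained broadcasts compose to the same result as the single broadcast_in_dim. A minimal PyTorch sketch of the equivalence (not part of the original comment; t33 and the shapes below are stand-ins for the T33/T37/T42/T47 chain):

import torch

# Stand-in for T33: the [8192] tensor produced by the reductions above.
t33 = torch.randn(8192)

# Chained form, mirroring T37 -> T42 -> T47.
t37 = t33.reshape(1, 8192)               # broadcast_in_dim(..., shape=[1, 8192], broadcast_dims=[1])
t42 = t37.reshape(1, 8192, 1)            # broadcast_in_dim(..., shape=[1, 8192, 1], broadcast_dims=[0, 1])
t47_chained = t42.expand(1, 8192, 4096)  # broadcast_in_dim(..., shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])

# Collapsed form, mirroring the single suggested broadcast_in_dim.
t47_direct = t33.reshape(1, 8192, 1).expand(1, 8192, 4096)

assert torch.equal(t47_chained, t47_direct)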

naoyam self-assigned this Sep 16, 2024
naoyam (Collaborator) commented Sep 16, 2024

The error is due to the topology of a persistent segment. I'll look into it.

naoyam added a commit that referenced this issue Sep 16, 2024
naoyam (Collaborator) commented Sep 16, 2024

Update: still a draft, but neither the original repro nor the simplified one results in an error with #2946.

naoyam (Collaborator) commented Sep 19, 2024

Closed in #2946

naoyam closed this as completed Sep 19, 2024
naoyam added a commit that referenced this issue Sep 19, 2024
…d and backward directions (#2946)

Finding resolution points fails with a persistent pattern that has non-straight-line dependencies. For example, here's a simplified pattern from the repro of #1123:


![Scratch-31](https://github.com/user-attachments/assets/cac62d10-acee-4f4b-89e2-71c19d7a5557)

`T3` is the persistent tensor in this case, and `T7` is the resolution point. While it may not be apparent immediately, `T3` cannot be inlined into `T4` or `T9`, since that would require `T2` and `T7`, and also `T6` and `T5`, to be inlined as well, which is not possible because of the reduction path.

The existing resolution analysis only looks at consumers of the
persistent buffer, so it first looks at `T9` but it stops there and
fails to detect the resolution point.

This PR adds a more thorough analysis that traverses both the forward and backward directions. In the above case, after visiting `T9`, it moves on to `T2` and then finds the resolution point of `T3`.
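To make the traversal idea concrete, here is a minimal sketch on a toy producer/consumer graph (illustration only, with hypothetical tensor names; this is not the nvFuser implementation). A consumer-only walk from the persistent buffer can dead-end without ever seeing the reduction branch, while a walk that may also step back to producers covers the tensors where the two paths can meet:

from collections import deque

# Toy graph: each tensor maps to its direct consumers. "buf" plays the role of
# the persistent buffer, "red" the reduction branch, and "resolve" the point
# where the buffer could be resolved. All names are hypothetical.
consumers = {
    "in":      ["buf", "red"],   # common producer feeding both the buffer and a reduction
    "buf":     ["use"],
    "red":     ["resolve"],
    "use":     [],
    "resolve": [],
}
producers = {}
for src, dsts in consumers.items():
    for dst in dsts:
        producers.setdefault(dst, []).append(src)

def reachable(start, follow_producers=False):
    # Tensors reachable from `start` via consumer edges, optionally also producer edges.
    seen, queue = set(), deque([start])
    while queue:
        n = queue.popleft()
        neighbors = list(consumers.get(n, []))
        if follow_producers:
            neighbors += producers.get(n, [])
        for m in neighbors:
            if m not in seen:
                seen.add(m)
                queue.append(m)
    return seen

# A consumer-only walk never reaches the reduction branch.
print(sorted(reachable("buf")))                         # ['use']
# Allowing backward steps exposes the producer side and, through it, the reduction path.
print(sorted(reachable("buf", follow_producers=True)))  # ['buf', 'in', 'red', 'resolve', 'use']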

I originally planned to completely replace the existing analysis with the new one, but that looked like it would open a can of worms, so I backed off and only use the new analysis as a fallback to the existing one. Any fusion that works with the existing analysis should not be affected, since the fallback path is not taken for it; only fusions that would fail with the existing analysis use the new one. A simplified C++ repro extracted from #1123 is also added.

Regarding the issues with replacing the existing analysis: there is a small number of test cases where the old and new analyses generate different resolution points. Some of them seem to come down to what exactly resolution should mean. Since it isn't urgent, I decided not to pursue this further. One interesting case, however, was a fusion with a persistent buffer that doesn't actually seem to need to be persistent (see #2954). I'm not sure how important that is, but I'll leave it as an open issue as well.

Closes #1123