codegen error: !resolution.resolution_points_.empty() INTERNAL ASSERT FAILED - Could not resolve persistent buffer #1123
This bug is still reproducible, and we're now hitting it with Riccardo's work on saved-tensor recomputation to enable longer sequences and larger-parameter-count models. @naoyam, who would be the right person to take a look and fix the problem? The fusion definition is a bit shorter and the error message is the same. Here's a script to reproduce the error:

# CUDA devices:
# 0: NVIDIA H100 80GB HBM3
# 1: NVIDIA H100 80GB HBM3
# 2: NVIDIA H100 80GB HBM3
# 3: NVIDIA H100 80GB HBM3
# 4: NVIDIA H100 80GB HBM3
# 5: NVIDIA H100 80GB HBM3
# 6: NVIDIA H100 80GB HBM3
# 7: NVIDIA H100 80GB HBM3
# torch version: 2.5.0a0+gitc0436c5
# cuda version: 12.6
# nvfuser version: 0.2.10+gite3e8485
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 8192, 4096], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 8192, 4096], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[8192, 4096], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[4096], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.ops.cast(T0, dtype=DataType.Float)
    T5 = fd.ops.cast(T1, dtype=DataType.Float)
    T6 = fd.ops.add(T5, T4)
    S7 = fd.define_scalar(1, dtype=DataType.Int)
    S8 = fd.define_scalar(8192, dtype=DataType.Int)
    S9 = fd.define_scalar(4096, dtype=DataType.Int)
    T11 = fd.ops.reshape(T2, new_shape=[S7, S8, S9])
    T12 = fd.ops.cast(T3, dtype=DataType.Float)
    T13 = fd.ops.sum(T6, dims=[2], keepdim=False, dtype=DataType.Null)
    T14 = fd.ops.cast(T11, dtype=DataType.Float)
    T19 = fd.ops.broadcast_in_dim(T12, shape=[1, 8192, 4096], broadcast_dims=[2])
    T24 = fd.ops.broadcast_in_dim(T13, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    T25 = fd.ops.mul(T19, T14)
    T26 = fd.ops.mul(T6, T25)
    T27 = fd.ops.sum(T26, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T32 = fd.ops.broadcast_in_dim(T27, shape=[1, 8192, 1], broadcast_dims=[1])
    T33 = fd.ops.sum(T32, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T37 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192], broadcast_dims=[1])
    T42 = fd.ops.broadcast_in_dim(T37, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    T47 = fd.ops.broadcast_in_dim(T42, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])
    T52 = fd.ops.broadcast_in_dim(T24, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])
    T53 = fd.ops.mul(T6, T47)
    T54 = fd.ops.mul(T52, T25)
    T55 = fd.ops.add(T54, T53)
    T56 = fd.ops.cast(T55, dtype=DataType.BFloat16)
    S57 = fd.define_scalar(8192, dtype=DataType.Int)
    S58 = fd.define_scalar(4096, dtype=DataType.Int)
    T60 = fd.ops.reshape(T56, new_shape=[S57, S58])
    fd.add_output(T55)
    fd.add_output(T60)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 8192, 4096), (33554432, 4096, 1)),
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 8192, 4096), (33554432, 4096, 1)),
    torch.randn(33554432, dtype=torch.bfloat16, device='cuda:0').as_strided((8192, 4096), (4096, 1)),
    torch.randn(4096, dtype=torch.bfloat16, device='cuda:0').as_strided((4096,), (1,)),
]

fd.execute(inputs)
It seems like this part is problematic:

T37 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192], broadcast_dims=[1])
T42 = fd.ops.broadcast_in_dim(T37, shape=[1, 8192, 1], broadcast_dims=[0, 1])
T47 = fd.ops.broadcast_in_dim(T42, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])

Changing this to the following avoids the error:

T47 = fd.ops.broadcast_in_dim(T33, shape=[1, 8192, 4096], broadcast_dims=[1])

So it seems the persistent buffer resolution traversal doesn't handle repeated broadcasts properly.
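For what it's worth, the chained and collapsed forms should be value-equivalent, so the workaround shouldn't change the computation. Below is a small sanity check, written against the same nvfuser Python frontend as the repro above; the shapes match T33/T47 there, but the helper names are just for illustration.

import torch
from nvfuser import FusionDefinition, DataType

def chained(fd: FusionDefinition) -> None:
    # Same broadcast chain as T33 -> T37 -> T42 -> T47 in the repro.
    t = fd.define_tensor(shape=[8192], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    a = fd.ops.broadcast_in_dim(t, shape=[1, 8192], broadcast_dims=[1])
    b = fd.ops.broadcast_in_dim(a, shape=[1, 8192, 1], broadcast_dims=[0, 1])
    c = fd.ops.broadcast_in_dim(b, shape=[1, 8192, 4096], broadcast_dims=[0, 1, 2])
    fd.add_output(c)

def collapsed(fd: FusionDefinition) -> None:
    # Single broadcast straight from the 1-D tensor, as in the workaround.
    t = fd.define_tensor(shape=[8192], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    c = fd.ops.broadcast_in_dim(t, shape=[1, 8192, 4096], broadcast_dims=[1])
    fd.add_output(c)

with FusionDefinition() as fd_chained:
    chained(fd_chained)
with FusionDefinition() as fd_collapsed:
    collapsed(fd_collapsed)

x = [torch.randn(8192, device='cuda:0')]
out_chained = fd_chained.execute(x)[0]
out_collapsed = fd_collapsed.execute(x)[0]
print(torch.equal(out_chained, out_collapsed))  # expected: True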
The error is due to the topology of a persistent segment. I'll look into it.
Update: Still a draft, but neither the original nor the simplified repro results in an error with #2946.
Closed in #2946 |
…d and backward directions (#2946)

Finding resolution points fails with a persistent pattern that has non-straight-line dependencies. For example, here's a simplified pattern from the repro of #1123:

![Scratch-31](https://github.com/user-attachments/assets/cac62d10-acee-4f4b-89e2-71c19d7a5557)

`T3` is the persistent tensor in this case, and `T7` is the resolution point. While it may not be immediately apparent, `T3` cannot be inlined into `T4` or `T9`, since that would require `T2` and `T7`, and also `T6` and `T5`, to be inlined as well, which is not possible because of the reduction path. The existing resolution analysis only looks at consumers of the persistent buffer, so it first looks at `T9`, but it stops there and fails to detect the resolution point. This PR adds a more thorough analysis that traverses in both the forward and backward directions. In the above case, after visiting `T9`, it moves on to `T2` and then finds the resolution point at `T3`.

I originally planned to completely replace the existing analysis with the new one, but it looks like that opened a can of worms, so I backed off from the plan and only used the new analysis as a fallback to the existing one. Any fusion that works with the existing analysis should not be affected, as the fallback path should not be taken; only fusions that would fail with the existing analysis use the new one. A simplified C++ repro extracted from #1123 is also added.

Regarding the issues with replacing the existing analysis: there are a small number of test cases where the old and new analyses generate different resolution points. Some of them seem to be a matter of what exactly "resolution" should mean. Since it isn't urgent, I decided not to pursue this further. One interesting case, however, was a fusion with a persistent buffer that doesn't actually seem to need to be persistent (see #2954). I'm not sure how important that is, but I'll leave it as an open issue as well.

Closes #1123
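To make the idea more concrete, here is a toy sketch in Python (not the actual nvFuser C++ analysis) of one way to think about a resolution point: a tensor where two different consumer branches of the persistent buffer reconverge, so the buffer can finally be freed. The graph below loosely mirrors the simplified pattern above; the edges and this exact definition of "resolution" are illustrative assumptions, not the traversal added in #2946.

from collections import deque

# Toy use-def graph: producer -> consumers. T3 is the persistent tensor with
# two consumer branches (a reduction path T4 -> T5 -> T6 and a direct use at
# T9) that meet again at T7.
consumers = {
    "T2": ["T3"],
    "T3": ["T4", "T9"],
    "T4": ["T5"],
    "T5": ["T6"],
    "T6": ["T7"],
    "T9": ["T7"],
    "T7": [],
}

def forward_reachable(start):
    # All tensors reachable from `start` by following consumer edges.
    seen, queue = set(), deque([start])
    while queue:
        tv = queue.popleft()
        for nxt in consumers.get(tv, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen

def resolution_points(persistent):
    # Tensors reachable from more than one consumer branch of the persistent
    # buffer, i.e. points where the branches reconverge.
    branches = [forward_reachable(c) | {c} for c in consumers[persistent]]
    hits = set()
    for i, a in enumerate(branches):
        for b in branches[i + 1:]:
            hits |= a & b
    return hits

print(resolution_points("T3"))  # -> {'T7'}

On this toy graph the search prints {'T7'}, matching the resolution point called out in the description; an analysis that only walks forward from one branch and stops at the first consumer would miss it.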
Repro python script: