Benchmark split + silu + mul fusion #2921
base: main
Conversation
```python
def split_silu_fwd_fusion(fd: FusionDefinition, dtype: DataType, size: tuple):
    T0 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    S0, S1 = size
    # Split the last dimension into two halves.
    T1 = fd.ops.slice(
        T0, start_indices=[0, 0], end_indices=[S0, S1 // 2], strides=[1, 1]
    )
    T2 = fd.ops.slice(
        T0, start_indices=[0, S1 // 2], end_indices=[S0, S1], strides=[1, 1]
    )
    if dtype in PROMOTE_DTYPES:
        T1 = fd.ops.cast(T1, dtype=DataType.Float)
        T2 = fd.ops.cast(T2, dtype=DataType.Float)
    # silu(T1) = T1 * sigmoid(T1), with sigmoid expanded as 1 / (1 + exp(-T1)).
    T3 = fd.ops.neg(T1)
    T4 = fd.ops.exp(T3)
    S5 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T6 = fd.ops.add(S5, T4)
    T7 = fd.ops.reciprocal(T6)
    T8 = fd.ops.mul(T1, T7)
    # Multiply the activated first half by the second half.
    T9 = fd.ops.mul(T8, T2)
    if dtype in PROMOTE_DTYPES:
        T9 = fd.ops.cast(T9, dtype=dtype)
    fd.add_output(T9)
```
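For reference, a minimal eager-mode sketch of what this forward definition computes (the `chunk`-based variant; the function name and shapes here are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

def split_silu_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply SiLU to the first half,
    # and multiply elementwise with the second half.
    a, b = torch.chunk(x, 2, -1)
    return F.silu(a) * b

# Example: torch.randn(16, 4096) -> output of shape (16, 2048).
```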
@IvanYashchuk Could you confirm if this matches with the case you're looking at?
Yes, the forward definition matches the case I was looking at.
```python
def split_silu_bwd_fusion(fd: FusionDefinition, dtype: DataType, size: tuple):
    # T0: forward input; T1: incoming (output) gradient.
    T0 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    T1 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    S0, S1 = size
    # Recompute the two halves of the forward input.
    T2 = fd.ops.slice(
        T0, start_indices=[0, 0], end_indices=[S0, S1 // 2], strides=[1, 1]
    )
    T3 = fd.ops.slice(
        T0, start_indices=[0, S1 // 2], end_indices=[S0, S1], strides=[1, 1]
    )

    if dtype in PROMOTE_DTYPES:
        T1 = fd.ops.cast(T1, dtype=DataType.Float)
        T2 = fd.ops.cast(T2, dtype=DataType.Float)
        T3 = fd.ops.cast(T3, dtype=DataType.Float)

    # Recompute sigmoid(T2) and silu(T2).
    T4 = fd.ops.neg(T2)
    T5 = fd.ops.exp(T4)
    S6 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T7 = fd.ops.add(S6, T5)
    T8 = fd.ops.reciprocal(T7)  # sigmoid(T2)
    T9 = fd.ops.mul(T2, T8)     # silu(T2)
    # Gradients of the elementwise product and of silu.
    T10 = fd.ops.mul(T3, T1)
    T11 = fd.ops.mul(T9, T1)    # gradient w.r.t. the second half
    T12 = fd.ops.mul(T8, T10)
    T13 = fd.ops.mul(T2, T10)
    T14 = fd.ops.neg(T13)
    T15 = fd.ops.mul(T14, T8)
    T16 = fd.ops.mul(T15, T8)
    T17 = fd.ops.mul(T16, T5)
    T18 = fd.ops.neg(T17)
    T19 = fd.ops.add(T12, T18)  # gradient w.r.t. the first half
    # Backward of slice is pad: zero-pad each half-gradient into place and sum.
    S20 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T21 = fd.ops.pad(T11, [S1 // 2, 0, 0, 0], S20)
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.pad(T19, [0, S1 // 2, 0, 0], S22)
    T24 = fd.ops.add(T21, T23)
    if dtype in PROMOTE_DTYPES:
        T24 = fd.ops.cast(T24, dtype=dtype)
    fd.add_output(T24)
```
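As a sanity check, this pad-based backward should agree with what PyTorch autograd produces for the same reference function; a minimal sketch (sizes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(16, 4096, requires_grad=True)
a, b = torch.chunk(x, 2, -1)
out = F.silu(a) * b
grad_out = torch.randn_like(out)
# grad_x assembles the two per-half gradients, which the
# fusion above builds via pad + add.
(grad_x,) = torch.autograd.grad(out, x, grad_out)
```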
@IvanYashchuk Could you confirm if this matches with the case you're looking at?
This case is slightly different because, currently, Thunder produces different backward definitions depending on whether `split` or `chunk` was used. Here no `cat` is present (backward of `slice` is `pad`). For `split` there's an explicit rule that backward of `split` is `cat`. I suggest adding, in addition to this definition, also the one with `cat`. It's unclear what's easier for nvFuser to support performantly. The other backward variant can be generated by using `x, y = torch.split(inputs[0], inputs[0].shape[-1] // 2, -1)` instead of `x, y = torch.chunk(inputs[0], 2, -1)` in forward.
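For concreteness, a sketch of the two forward variants being discussed (same math; only the slicing op, and hence the traced backward, differs):

```python
import torch
import torch.nn.functional as F

def fwd_chunk(x):
    a, b = torch.chunk(x, 2, -1)  # backward of the underlying slice is pad
    return F.silu(a) * b

def fwd_split(x):
    a, b = torch.split(x, x.shape[-1] // 2, -1)  # backward of split is cat
    return F.silu(a) * b
```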
Thanks @IvanYashchuk.
@Priya2698, please also leave a note about the original Thunder definition.
If we are adding another variant, we should do the same for fwd. My concern is that the baselines (torch.compile and eager) may perform differently for the two variants (I will have to verify) even if the fusion definition is the same.
Wdyt @naoyam @IvanYashchuk -- is there a preference for selecting only one? The upside of not adding an additional forward benchmark would be time saved on CI and less redundancy in nvFuser benchmarking.
Else, we can essentially have two benchmarks each for `fwd` and `bwd`.
Can we have two as separate benchmarks but each with a smaller set of input sizes so that the overall workload would be roughly the same?
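One hypothetical way to realize this, assuming the benchmarks are parametrized with pytest (names and size lists here are illustrative, not from the PR):

```python
import pytest

# Full sweep used when benchmarking a single variant.
FULL_SIZES = [(2048, 2**i) for i in range(10, 15)]

# Halve the sweep per variant so total runtime stays roughly the same.
CHUNK_SIZES = FULL_SIZES[::2]   # slice/pad-based variant
SPLIT_SIZES = FULL_SIZES[1::2]  # split/cat-based variant

@pytest.mark.parametrize("size", CHUNK_SIZES)
def test_split_silu_fwd_chunk_benchmark(size):
    ...
```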
@Priya2698 Thanks for the PR. Have you checked the performance? Does it match what Ivan reported?
Fwd benchmark: http://nv/eke
Update: I noticed some discrepancy in the torch.compile benchmarks (PR #3300) -- I will be re-running these benchmarks and will update the results here.
The fusion definitions correspond to the split + silu + mul pattern and are obtained from Thunder.