Benchmark split + silu + mul fusion #2921

Open

Priya2698 wants to merge 4 commits into main

Conversation

Priya2698 commented Sep 9, 2024

The fusion definition corresponds to:

def split_silu(xy):
    x, y = torch.chunk(xy, 2, -1)
    return torch.nn.functional.silu(x) * y

The definitions are obtained from Thunder.
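
For reference, a minimal eager sketch (the function name is illustrative) of the same decomposition the fusion definitions below spell out op by op, with silu(x) = x * sigmoid(x) and sigmoid expanded as 1 / (1 + exp(-x)):

import torch

def split_silu_decomposed(xy):
    # same math as torch.nn.functional.silu(x) * y, written with the
    # primitives the fusion uses: chunk (as two slices), neg, exp, add,
    # reciprocal, mul
    x, y = torch.chunk(xy, 2, -1)
    sigmoid_x = torch.reciprocal(1.0 + torch.exp(torch.neg(x)))
    return x * sigmoid_x * y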

Comment on lines 12 to 39
def split_silu_fwd_fusion(fd: FusionDefinition, dtype: DataType, size: tuple):
    T0 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    S0, S1 = size
    # chunk the input into x (first half) and y (second half) along the last dim
    T1 = fd.ops.slice(
        T0, start_indices=[0, 0], end_indices=[S0, S1 // 2], strides=[1, 1]
    )
    T2 = fd.ops.slice(
        T0, start_indices=[0, S1 // 2], end_indices=[S0, S1], strides=[1, 1]
    )
    if dtype in PROMOTE_DTYPES:
        T1 = fd.ops.cast(T1, dtype=DataType.Float)
        T2 = fd.ops.cast(T2, dtype=DataType.Float)
    # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    T3 = fd.ops.neg(T1)
    T4 = fd.ops.exp(T3)
    S5 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T6 = fd.ops.add(S5, T4)
    T7 = fd.ops.reciprocal(T6)
    T8 = fd.ops.mul(T1, T7)
    T9 = fd.ops.mul(T8, T2)
    if dtype in PROMOTE_DTYPES:
        T9 = fd.ops.cast(T9, dtype=dtype)
    fd.add_output(T9)
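
As a quick sanity check, a sketch of how this definition could be executed and compared against the eager reference via the nvFuser Python frontend; the shape, dtype, and tolerances are illustrative assumptions, and split_silu_fwd_fusion / PROMOTE_DTYPES are taken from the benchmark file above:

import torch
from nvfuser import FusionDefinition, DataType

size = (2048, 8192)  # illustrative (rows, 2 * hidden)
xy = torch.randn(size, device="cuda", dtype=torch.bfloat16)

with FusionDefinition() as fd:
    split_silu_fwd_fusion(fd, DataType.BFloat16, size)

(nvf_out,) = fd.execute([xy])

# eager reference
x, y = torch.chunk(xy, 2, -1)
ref = torch.nn.functional.silu(x.float()) * y.float()
torch.testing.assert_close(nvf_out.float(), ref, atol=1e-2, rtol=1e-2)
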
@IvanYashchuk Could you confirm if this matches with the case you're looking at?

Yes, the forward definition matches the case I was looking at.

Comment on lines 13 to 64
def split_silu_bwd_fusion(fd: FusionDefinition, dtype: DataType, size: tuple):
    T0 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    T1 = fd.define_tensor(
        shape=[-1, -1],
        contiguity=[True, True],
        dtype=dtype,
        is_cpu=False,
        stride_order=[1, 0],
    )
    S0, S1 = size
    # re-slice the forward input into x (T2) and y (T3)
    T2 = fd.ops.slice(
        T0, start_indices=[0, 0], end_indices=[S0, S1 // 2], strides=[1, 1]
    )
    T3 = fd.ops.slice(
        T0, start_indices=[0, S1 // 2], end_indices=[S0, S1], strides=[1, 1]
    )

    if dtype in PROMOTE_DTYPES:
        T1 = fd.ops.cast(T1, dtype=DataType.Float)
        T2 = fd.ops.cast(T2, dtype=DataType.Float)
        T3 = fd.ops.cast(T3, dtype=DataType.Float)

    # recompute sigmoid(x) and silu(x)
    T4 = fd.ops.neg(T2)
    T5 = fd.ops.exp(T4)
    S6 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T7 = fd.ops.add(S6, T5)
    T8 = fd.ops.reciprocal(T7)
    T9 = fd.ops.mul(T2, T8)
    # T11 is the gradient wrt y, T19 the gradient wrt x
    T10 = fd.ops.mul(T3, T1)
    T11 = fd.ops.mul(T9, T1)
    T12 = fd.ops.mul(T8, T10)
    T13 = fd.ops.mul(T2, T10)
    T14 = fd.ops.neg(T13)
    T15 = fd.ops.mul(T14, T8)
    T16 = fd.ops.mul(T15, T8)
    T17 = fd.ops.mul(T16, T5)
    T18 = fd.ops.neg(T17)
    T19 = fd.ops.add(T12, T18)
    # pad the two gradient halves back to the full width and add them
    S20 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T21 = fd.ops.pad(T11, [S1 // 2, 0, 0, 0], S20)
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.pad(T19, [0, S1 // 2, 0, 0], S22)
    T24 = fd.ops.add(T21, T23)
    if dtype in PROMOTE_DTYPES:
        T24 = fd.ops.cast(T24, dtype=dtype)
    fd.add_output(T24)
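
Similarly, a sketch of checking this backward definition against autograd; the input order (forward input T0 first, incoming gradient T1 second) follows the define_tensor order above, and sizes, dtype, and tolerances are illustrative assumptions:

import torch
from nvfuser import FusionDefinition, DataType

size = (2048, 8192)
xy = torch.randn(size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_out = torch.randn(size[0], size[1] // 2, device="cuda", dtype=torch.bfloat16)

# eager reference gradient via autograd
x, y = torch.chunk(xy, 2, -1)
(torch.nn.functional.silu(x) * y).backward(grad_out)

with FusionDefinition() as fd:
    split_silu_bwd_fusion(fd, DataType.BFloat16, size)

(nvf_grad,) = fd.execute([xy.detach(), grad_out])
torch.testing.assert_close(nvf_grad.float(), xy.grad.float(), atol=1e-2, rtol=1e-2)
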
@IvanYashchuk Could you confirm if this matches with the case you're looking at?

This case is slightly different because, currently, Thunder produces different backward definitions depending on whether split or chunk was used. Here, no cat is present (the backward of slice is pad). For split, there's an explicit rule that the backward of split is cat. I suggest adding, in addition to this definition, the one with cat as well. It's unclear which is easier for nvFuser to support performantly. The other backward variant can be generated by using x, y = torch.split(inputs[0], inputs[0].shape[-1] // 2, -1) instead of x, y = torch.chunk(inputs[0], 2, -1) in the forward.
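
For reference, a minimal sketch of the torch.split variant of the forward mentioned above, which would lead Thunder to emit a cat-based backward (the function name is illustrative):

import torch

def split_silu_split_variant(xy):
    # same forward math, but Thunder's backward rule for split concatenates
    # the two gradient halves with cat instead of pad + add
    x, y = torch.split(xy, xy.shape[-1] // 2, -1)
    return torch.nn.functional.silu(x) * y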

Thanks @IvanYashchuk.

@Priya2698, please also leave a note about the original Thunder definition.

This case is slightly different because, currently, Thunder produces different backward definitions depending on whether split or chunk was used. Here, no cat is present (the backward of slice is pad). For split, there's an explicit rule that the backward of split is cat. I suggest adding, in addition to this definition, the one with cat as well. It's unclear which is easier for nvFuser to support performantly. The other backward variant can be generated by using x, y = torch.split(inputs[0], inputs[0].shape[-1] // 2, -1) instead of x, y = torch.chunk(inputs[0], 2, -1) in the forward.

If we are adding another variant, we should do the same for fwd. My concern is that the baselines (torch.compile and eager) may have different performance for the two variants (I will have to verify) even if the fusion definition is the same.
Wdyt? @naoyam @IvanYashchuk -- is there a preference for selecting only one? The upside of not adding an additional forward benchmark would be saving CI time and minimizing redundancy in nvFuser benchmarking.
Otherwise, we can essentially have two benchmarks each for fwd and bwd.

Can we have the two as separate benchmarks, but each with a smaller set of input sizes, so that the overall workload stays roughly the same?

naoyam commented Sep 9, 2024

@Priya2698 Thanks for the PR. Have you checked the performance? Does it match with what Ivan reported?

Priya2698 commented Sep 10, 2024

@Priya2698 Thanks for the PR. Have you checked the performance? Does it match with what Ivan reported?

Fwd benchmark: http://nv/eke
Bwd benchmark: http://nv/ekf
I can see the regression against torch.compile for the forward benchmark, but not for the backward.
I will add the torch.split version for the backward and re-run.

@Priya2698

Update:

I noticed some discrepancy in the torch.compile benchmarks (PR #3300) -- I will re-run these benchmarks and update the results here.
