Cache misses for fusion definitions different by a few ops compiled one after another #2916

Open
IvanYashchuk opened this issue Sep 6, 2024 · 3 comments

IvanYashchuk commented Sep 6, 2024

Note: This is only an experiment and no real model execution is currently blocked by this!

While working on a memory-usage fix for Gemma 7B in Thunder (Lightning-AI/lightning-thunder#474 (comment)), I discovered an interesting property of transformers combined with fusions and rematerialization and ran into larger-than-usual compile times with nvFuser.

Transformers follow this pattern (adapted from LitGPT link1, link2):

x_add_mlp = x
for i in range(num_layers):
    x_normed_1 = norm_1(x_add_mlp)
    attention_output = attention(x_normed_1) # NONFUSIBLE
    x_add_attn = attention_output + x_add_mlp
    x_normed_2 = norm_2(x_add_attn)
    mlp_output = mlp(x_normed_2)        # NONFUSIBLE
    x_add_mlp = mlp_output + x_add_attn

Operations can be fused into two regions:

def inner_iter1(mlp_output, x_add_attn):
    x_add_mlp = mlp_output + x_add_attn
    x_normed_1 = norm_1(x_add_mlp)
    return x_normed_1, x_add_mlp

and

def inner_iter2(attention_output, x_add_mlp):
    x_add_attn = attention_output + x_add_mlp
    x_normed_2 = norm_2(x_add_attn)
    return x_normed_2, x_add_attn

Here we must return the x_normed_1 tensor as an output of the fused region because it is consumed by the attention operation, and we must also return x_normed_2 because it is consumed by mlp. The x_add_mlp tensor is returned as a residual to be added to the attention output in the next iteration; similarly, the x_add_attn tensor is returned as a residual to be added to the mlp output in the next iteration.

In addition, we must keep both attention_output and mlp_output in memory for the backward pass. Since they are kept anyway, we can avoid returning and materializing x_add_mlp and x_add_attn, because they can be recomputed from the saved tensors.
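As a rough illustration (a hand-written sketch in the same pseudocode style as above, not Thunder output), the recomputation only replays the cheap residual adds from the saved nonfusible outputs:

def recompute_residuals(x_add_mlp_prev, attention_output, mlp_output):
    # Rebuild the two residual tensors of one layer from the residual entering
    # the layer and the saved nonfusible outputs, instead of returning them
    # from the fusions.
    x_add_attn = attention_output + x_add_mlp_prev
    x_add_mlp = mlp_output + x_add_attn
    return x_add_attn, x_add_mlp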

The following patch for Thunder provides this residual path recomputation:

diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 7db7b88f..4967f0c9 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -347,7 +347,7 @@ def find_cut(
     def add_edges(var):
         var_name = var.name
         weight = get_weight(var)
-        weight = weight / 2.0 if var_name in (x.name for x in producer.args) else weight
+        weight = weight * 0.0 if var_name in (x.name for x in producer.args) else weight
         add_edge(var_name + "_in", var_name + "_out", capacity=weight)
         for user in combined_consumers._dict.get(var_name, tuple()):
             if user.sym.id in sym_skip_list:

This forces Thunder to generate a different fusion for each iteration: every fusion differs from the previous one by one additional add operation and one additional input. I think nvFuser treats each of these fusions as a completely new one and applies full compilation to each, which scales badly with the number of layers.
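To make the growth concrete, here is a hand-written sketch (same pseudocode style as above, with assumed names; not actual Thunder/nvFuser fusion definitions) of two consecutive fusions once residuals are recomputed rather than materialized:

def fusion_before_mlp_layer_1(x, attn_out_1):
    # 1 add, 2 inputs
    x_add_attn_1 = attn_out_1 + x
    return norm_2(x_add_attn_1)

def fusion_before_attn_layer_2(x, attn_out_1, mlp_out_1):
    # one more add and one more input than the previous fusion: 2 adds, 3 inputs
    x_add_attn_1 = attn_out_1 + x
    x_add_mlp_1 = mlp_out_1 + x_add_attn_1
    return norm_1(x_add_mlp_1)

# ...and so on: the fusion before the mlp of layer 2 additionally takes
# attn_out_2 and contains 3 adds, etc.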

Is there anything nvFuser could improve in this scenario?

Here's the code to reproduce the behavior (first apply the patch above to Thunder, then vary NUM_LAYERS and observe the time it takes to compile all fusions):

import torch
import thunder
from time import time
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_layers", type=int, default=10)

NUM_LAYERS = parser.parse_args().num_layers

def norm_1(x):
    # Replaced layer norm with tanh for simplicity
    return torch.nn.functional.tanh(x)

def norm_2(x):
    # Replaced layer norm with tanh for simplicity
    return torch.nn.functional.tanh(x)

def mlp(weight, x):
    # Replaced MLP with linear layer for simplicity
    return weight @ x

def attn(weight, x):
    # Replaced attention with linear layer for simplicity
    return weight @ x

def func(weight, x):
    x_add_mlp = x
    global NUM_LAYERS
    for _ in range(NUM_LAYERS):
        x_normed_1 = norm_1(x_add_mlp)
        attention_output = attn(weight, x_normed_1)
        x_add_attn = attention_output + x_add_mlp
        x_normed_2 = norm_2(x_add_attn)
        mlp_output = mlp(weight, x_normed_2)
        x_add_mlp = mlp_output + x_add_attn
    return x_add_mlp

weight = torch.randn(10, 10, device="cuda", requires_grad=True)
x = torch.randn(10, 10, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)
start = time()
out = jfunc(weight, x)
torch.cuda.synchronize()
end = time()
# print(thunder.last_traces(jfunc)[-1])
print(f"Compilation time {NUM_LAYERS=}: {end - start}")

Run the script with:

for i in {0..7} ; do python script.py --num_layers=$((2**i)) ; done

Output on an AMD Ryzen 9 5950X (time in seconds):

Compilation time NUM_LAYERS=1: 0.32010936737060547
Compilation time NUM_LAYERS=2: 0.5338308811187744
Compilation time NUM_LAYERS=4: 1.0335056781768799
Compilation time NUM_LAYERS=8: 2.362748861312866
Compilation time NUM_LAYERS=16: 6.8711018562316895
Compilation time NUM_LAYERS=32: 27.141274452209473
Compilation time NUM_LAYERS=64: 161.30480885505676
Compilation time NUM_LAYERS=128: 1392.7565586566925 (about 23 minutes!)

log-log plot: [figure: compilation time vs. NUM_LAYERS on log-log axes]
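As a quick sanity check on the scaling (a sketch computed only from the reported times above), the local power-law exponent between successive doublings grows rather than staying constant:

import math

# Reported compile times (seconds) from the table above.
layers = [16, 32, 64, 128]
times = [6.8711, 27.1413, 161.3048, 1392.7566]

# Local exponent b in time ~ NUM_LAYERS**b between successive doublings.
for (n1, t1), (n2, t2) in zip(zip(layers, times), zip(layers[1:], times[1:])):
    b = math.log(t2 / t1) / math.log(n2 / n1)
    print(f"{n1} -> {n2} layers: local exponent ~ {b:.2f}")

# Prints roughly 2.0, 2.6, and 3.1: compile time grows super-quadratically,
# and the exponent itself increases with the number of layers.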

naoyam commented Sep 6, 2024

Thanks @IvanYashchuk.

@Priya2698 As mentioned offline, first of all, please add it as a benchmark.

@Priya2698

Update:
There are two ways to add this benchmark:

  1. Generate the fusion definition using nvf_enable_matmul=True. This gives us a single fusion definition to benchmark, which the current benchmarking infra supports. This seems to get stuck from num_layers=16 onwards (I am investigating, since the compile time reported above was ~6 s).
  2. I will look at supporting benchmarking of functions that consist of multiple fusion definitions. Note that we use the fusion profiler for host-time measurements, so any host overhead incurred in other parts of the function will not be captured (see the timing sketch below).
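As an aside on point 2, here is a minimal wall-clock timing sketch (the helper name and structure are hypothetical, not part of the benchmarking infra) of the kind of whole-function measurement that would also capture host overhead outside the fusions:

import time

import torch

def time_whole_function(jitted_fn, *args, warmup=3, iters=10):
    # Wall-clock timing over the whole jitted function, including host-side
    # overhead between fusion executions, which per-fusion profiling does not
    # capture.
    for _ in range(warmup):
        jitted_fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        jitted_fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters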

@Priya2698

CC: @kevinstephano
