Note: This is only an experiment and no real model execution is currently blocked by this!
While working on a memory usage fix for Gemma 7B in Thunder (Lightning-AI/lightning-thunder#474 (comment)), I discovered an interesting property of transformers when combined with fusions and rematerialization, and hit larger-than-usual compile times with nvFuser.
Transformers follow this pattern (adapted from LitGPT link1, link2):
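A minimal sketch of the per-layer pattern, using the same simplified stand-ins as the reproducer further down (layer norm replaced by `tanh`, attention and MLP each replaced by a single matmul); `norm_1`, `attn`, `norm_2`, and `mlp` are the helpers defined in that reproducer:

```python
# One transformer block iteration, with norm_1/attn/norm_2/mlp as defined in
# the reproducer below (layer norm -> tanh, attention/MLP -> a single matmul).
x_normed_1 = norm_1(x_add_mlp)               # pre-attention normalization
attention_output = attn(weight, x_normed_1)  # attention (matmul stand-in)
x_add_attn = attention_output + x_add_mlp    # residual add
x_normed_2 = norm_2(x_add_attn)              # pre-MLP normalization
mlp_output = mlp(weight, x_normed_2)         # MLP (matmul stand-in)
x_add_mlp = mlp_output + x_add_attn          # residual add feeding the next layer
```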
Operations can be fused into two regions: roughly, the residual add producing `x_add_mlp` together with `norm_1`, and the residual add producing `x_add_attn` together with `norm_2`; the matmuls (`attn` and `mlp`) stay outside. Here we must return the `x_normed_1` tensor as an output of the fused operation because it is used by the attention operation. We also must return `x_normed_2` because it is consumed by `mlp`. The `x_add_mlp` tensor is returned as a residual tensor to be added to the attention output in the next iteration; similarly, `x_add_attn` is returned as a residual tensor to be added to the `mlp` output in the next iteration.
In addition to that, we must keep both `attention_output` and `mlp_output` in memory for the backward pass. If we keep them in memory, we can avoid returning and materializing `x_add_mlp` and `x_add_attn`, because they can be recomputed.
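A minimal sketch of that recomputation (hypothetical helper name; this is not the actual Thunder backward trace): with the matmul outputs saved for backward, the residual sums are just cheap elementwise adds that can be rebuilt on demand instead of being stored.

```python
import torch

def recompute_residuals(attention_output: torch.Tensor,
                        mlp_output: torch.Tensor,
                        x_add_mlp_prev: torch.Tensor):
    # Both residual sums are cheap elementwise adds over tensors that are
    # already saved for backward, so recomputing them avoids materializing
    # them as extra fusion outputs.
    x_add_attn = attention_output + x_add_mlp_prev
    x_add_mlp = mlp_output + x_add_attn
    return x_add_attn, x_add_mlp
```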
The following patch for Thunder enables this residual-path recomputation: it zeroes the min-cut weight of tensors that are already inputs of the producer fusion, making it free to cut at those inputs, so downstream intermediates (the residual sums) get recomputed instead of saved:
```diff
diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index 7db7b88f..4967f0c9 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -347,7 +347,7 @@ def find_cut(
     def add_edges(var):
         var_name = var.name
         weight = get_weight(var)
-        weight = weight / 2.0 if var_name in (x.name for x in producer.args) else weight
+        weight = weight * 0.0 if var_name in (x.name for x in producer.args) else weight
         add_edge(var_name + "_in", var_name + "_out", capacity=weight)
         for user in combined_consumers._dict.get(var_name, tuple()):
             if user.sym.id in sym_skip_list:
```
This forces Thunder to generate a different fusion for each iteration: each fusion differs from the previous one by one additional add operation and one additional input. I think nvFuser treats each of these fusions as a completely new one and applies full compilation to it, which scales badly with the number of layers.
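A hedged illustration of that growth (plain Python, not actual Thunder trace or nvFuser fusion IR): after the patch, the k-th fused region recomputes the whole residual chain from the matmul outputs produced so far, so it has one more input and one more add than the previous one.

```python
import torch

def fused_region_k(x, matmul_outputs_so_far):
    # Hypothetical shape of the k-th fused region after the patch: it takes x
    # plus the k matmul outputs produced so far and redoes the residual adds
    # instead of consuming a previously materialized sum. Each new layer adds
    # one more input tensor and one more add op to the next fusion.
    acc = x
    for t in matmul_outputs_so_far:
        acc = acc + t
    return torch.tanh(acc)  # normalization stand-in (tanh), as in the reproducer
```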
Is there anything nvFuser could improve in this scenario?
Here's the code to reproduce the behavior (first apply the patch above for Thunder, then vary NUM_LAYERS and observe the time it takes to compile all fusions):
```python
import torch
import thunder
from time import time
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_layers", type=int, default=10)
NUM_LAYERS = parser.parse_args().num_layers


def norm_1(x):
    # Replaced layer norm with tanh for simplicity
    return torch.nn.functional.tanh(x)


def norm_2(x):
    # Replaced layer norm with tanh for simplicity
    return torch.nn.functional.tanh(x)


def mlp(weight, x):
    # Replaced MLP with linear layer for simplicity
    return weight @ x


def attn(weight, x):
    # Replaced attention with linear layer for simplicity
    return weight @ x


def func(weight, x):
    x_add_mlp = x
    global NUM_LAYERS
    for _ in range(NUM_LAYERS):
        x_normed_1 = norm_1(x_add_mlp)
        attention_output = attn(weight, x_normed_1)
        x_add_attn = attention_output + x_add_mlp
        x_normed_2 = norm_2(x_add_attn)
        mlp_output = mlp(weight, x_normed_2)
        x_add_mlp = mlp_output + x_add_attn
    return x_add_mlp


weight = torch.randn(10, 10, device="cuda", requires_grad=True)
x = torch.randn(10, 10, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)

start = time()
out = jfunc(weight, x)
torch.cuda.synchronize()
end = time()

# print(thunder.last_traces(jfunc)[-1])
print(f"Compilation time {NUM_LAYERS=}: {end-start}")
```
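Run the script with, for example, `python repro.py --num_layers 16` (the filename `repro.py` is an assumption here; only the `--num_layers` flag comes from the script above), varying `--num_layers` to reproduce the scaling below.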
Output with an AMD Ryzen 9 5950X (time is in seconds):

```
Compilation time NUM_LAYERS=1: 0.32010936737060547
Compilation time NUM_LAYERS=2: 0.5338308811187744
Compilation time NUM_LAYERS=4: 1.0335056781768799
Compilation time NUM_LAYERS=8: 2.362748861312866
Compilation time NUM_LAYERS=16: 6.8711018562316895
Compilation time NUM_LAYERS=32: 27.141274452209473
Compilation time NUM_LAYERS=64: 161.30480885505676
Compilation time NUM_LAYERS=128: 1392.7565586566925
```

The NUM_LAYERS=128 run takes roughly 23 minutes!
Log-log plot of compilation time vs. NUM_LAYERS:
Update:
There are two ways to add this benchmark:

1. Generate the fusion definition using `nvf_enable_matmul=True`. This gives us one fusion definition to benchmark, which is supported by the current benchmarking infra. This seems to get stuck from `num_layers=16` onwards (I am investigating this, since the compile time reported above for that size was ~6 s).
2. I will look into supporting the benchmarking of functions consisting of multiple fusion definitions. Note that we use the fusion profiler for host-time measurements, so any host overhead incurred in other parts of the function will not be captured.