
support for pytorch autograd #60

Open
bghira opened this issue Dec 3, 2024 · 3 comments

Labels
enhancement New feature or request

Comments

@bghira

bghira commented Dec 3, 2024

hello, thank you for the outstanding work on quantised attention mechanisms. we see a 300% speedup during training on a 4090-based cluster. of course, this is because the QKV linears are not receiving gradients on the backward pass.

i've done a hackjob implementation that approximates the gradient, but the results are really not good. however, the speedup remains mostly intact.

it seems like an incredible challenge to support a quantised backward pass with meaningful gradients, keeping the rounded and quantised values in the computation graph.
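
for reference, one way such an approximation could look (just a sketch of the idea, not my actual code; ApproxSageAttn / approx_sage_attn are made-up names): keep the quantised kernel for the forward, then recompute plain attention in the backward and reuse its gradients:

import torch
from sageattention import sageattn_qk_int8_pv_fp16_cuda

class ApproxSageAttn(torch.autograd.Function):
    # quantised kernel in the forward, approximate gradient from a
    # full-precision recompute in the backward (illustrative only)

    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        return sageattn_qk_int8_pv_fp16_cuda(q, k, v)

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # differentiate a plain sdpa recompute instead of the quantised kernel,
        # so the gradients only approximate the forward that actually ran
        with torch.enable_grad():
            q_, k_, v_ = (t.detach().requires_grad_(True) for t in (q, k, v))
            out = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_)
            return torch.autograd.grad(out, (q_, k_, v_), grad_out)

def approx_sage_attn(q, k, v):
    return ApproxSageAttn.apply(q, k, v)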

i know the suggestion in another issue thread (unrelated to autograd) is that, since H100 is used for large-scale pretraining, it would be most useful to implement the forward pass for it before working on backward.

however, for 12B models like Flux it's notably faster to train with quantised attention even using an approximated gradient.

just wanted to open this issue to keep it trackable by others who are waiting for news or updates regarding this challenge.

@jason-huang03
Member

Thanks for your attention! Quantizing the backward pass poses additional, unexplored challenges. However, it also means great potential and opportunity. We will try our best to solve this challenge in the future.

@lodestone-rock

@jason-huang03 maybe just use the rematerialization trick (recomputation) to get the ctx for the backward pass?
basically the same approach as flash attention
then we could work out the backward kernel from there
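
something like this in plain pytorch shows the pattern i mean (reference math only, no sage kernel involved yet; RematAttention is just a placeholder name): save only q/k/v plus the softmax stats in the forward, then rebuild P inside the backward, like flash attention does per tile:

import math
import torch

class RematAttention(torch.autograd.Function):
    # flash-attention-style rematerialization: keep only q, k, v and the
    # per-row logsumexp, recompute the attention matrix in the backward

    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        s = torch.matmul(q, k.transpose(-2, -1)) * scale
        lse = torch.logsumexp(s, dim=-1, keepdim=True)   # softmax statistics
        o = torch.matmul(torch.exp(s - lse), v)
        ctx.save_for_backward(q, k, v, lse)              # note: P and O are not saved
        ctx.scale = scale
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, lse = ctx.saved_tensors
        scale = ctx.scale
        # rematerialize P (and O) from q, k, v and the saved logsumexp
        p = torch.exp(torch.matmul(q, k.transpose(-2, -1)) * scale - lse)
        o = torch.matmul(p, v)
        dv = torch.matmul(p.transpose(-2, -1), do)
        dp = torch.matmul(do, v.transpose(-2, -1))
        d = (do * o).sum(dim=-1, keepdim=True)           # row correction term
        ds = p * (dp - d)
        dq = torch.matmul(ds, k) * scale
        dk = torch.matmul(ds.transpose(-2, -1), q) * scale
        return dq, dk, dv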

@bghira
Author

bghira commented Dec 3, 2024

a test for checking whether this is resolved is simply:

import torch
from sageattention import sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp16_cuda

# Sample tensors with requires_grad=True
batch_size = 2
num_heads = 4
seq_length = 8
head_dim = 64

q = torch.randn((batch_size, num_heads, seq_length, head_dim), requires_grad=True, device='cuda', dtype=torch.float16)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Run each attention function and report whether autograd tracks its output.
def check(name, fn):
    try:
        output = fn(q, k, v)
        print(f"{name} output.requires_grad: {output.requires_grad}")
    except Exception as e:
        print(f"{name} raised: {e}")

check("sage triton", sageattn_qk_int8_pv_fp16_triton)
check("sage cuda", sageattn_qk_int8_pv_fp16_cuda)
check("sdpa", torch.nn.functional.scaled_dot_product_attention)

and the output:

% python test_sage.py
sage triton output.requires_grad: False
sage cuda output.requires_grad: False
sdpa output.requires_grad: True
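
once requires_grad comes through, a stricter follow-up check (just a suggestion, appended to the same test_sage.py so q/k/v are already defined) would be to run an actual backward and confirm the grads land:

# appended to test_sage.py: check that gradients actually reach q, k, v
output = sageattn_qk_int8_pv_fp16_triton(q, k, v)
output.sum().backward()
print(q.grad is not None, k.grad is not None, v.grad is not None)  # expect: True True True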

@jason-huang03 jason-huang03 added the enhancement New feature or request label Dec 11, 2024