Incompatibility in flash_attention_2 + Llama + Transformers>=4.43 + Autocast to fp16 #36224

poedator opened this issue Feb 17, 2025 · 1 comment
poedator commented Feb 17, 2025

System Info

Setting: inference or training of Llama with Automatic Mixed Precision (AMP) autocast from fp32 to fp16 + FlashAttention 2 (FA2).

I observed that with newer versions of the Transformers library (>=4.43), training (and inference) fails with RuntimeError: FlashAttention only supports fp16 and bf16. The error does not occur with GPT2 or with other parameter combinations. What is happening?

Given:

  • FA2 supports only fp16/bf16 and fails when it encounters fp32.
  • Autocast does not cast all operations to fp16.

The failure is caused by the fact that in transformers >= 4.43, positional embeddings in Llama are precomputed from the hidden_states (fp32) and are therefore also returned in fp32. This happens in LlamaModel.forward(), before the layers' forward passes, using the following code:
position_embeddings = self.rotary_emb(hidden_states, position_ids) link
These embeddings are then passed to the attention modules as a parameter.

In the attention class, we have:
cos, sin = position_embeddings link
These fp32 values are then multiplied with and added to the autocasted fp16 values of q and k. Autocast does not cast these elementwise operations, so type promotion leaves q_embed in fp32, and this causes FA2 to fail. SDPA, if used instead, handles the mixed dtypes without issues.
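To illustrate the dtype promotion, here is a minimal standalone sketch (it assumes a CUDA device and is not Transformers code): matmuls run in fp16 under autocast, but elementwise ops are not cast, so mixing in an fp32 tensor promotes the result back to fp32, which is exactly the dtype FA2 rejects.

```python
import torch

hidden = torch.randn(2, 16, 64, device="cuda")   # fp32 "hidden states"
w_q = torch.randn(64, 64, device="cuda")          # fp32 projection weight
cos = torch.randn(16, 64, device="cuda")          # fp32 rotary embedding

with torch.autocast("cuda", dtype=torch.float16):
    q = hidden @ w_q        # matmul is autocast to fp16 -> q is fp16
    q_embed = q * cos       # elementwise op is not cast: fp16 * fp32 -> fp32
    print(q.dtype)          # torch.float16
    print(q_embed.dtype)    # torch.float32
```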

Why didn't this happen before?
In transformers <= 4.41 or so, Llama's positional embeddings were recomputed in each layer (inefficiently) from value_states (fp16 under autocast) and were therefore also returned in fp16, so no error occurred.
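For reference, a hypothetical sketch of that pre-4.43 pattern (simplified, not the exact upstream class): the per-layer rotary call received value_states as its input and cast its cached cos/sin tables to that input's dtype, which is why everything stayed in fp16 under autocast.

```python
import torch
from torch import nn

class OldStyleRotaryEmbedding(nn.Module):
    # Hypothetical sketch of the per-layer rotary embedding used in older
    # versions: cos/sin caches are built once in fp32, and every call casts
    # them to the dtype of `x` (value_states, i.e. fp16 under autocast).
    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len):
        # The .to(dtype=x.dtype) cast is what kept the old path fp16-safe.
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
```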

Proposed solutions:

  1. cast cos and sin to q.dtype in apply_rotary_pos_emb() link (see the sketch after this list)
  2. cast position_embeddings to the target dtype right after they are created here
  3. do nothing and let people use SDPA when autocasting from fp32. This is not that bad, since SDPA is quite fast by now.
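A minimal sketch of option 1, following the rotate_half / apply_rotary_pos_emb pattern from modeling_llama.py (the signature is simplified here; the upstream function also accepts position_ids):

```python
import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input (same helper as in modeling_llama.py).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Option 1: cast cos/sin to the query dtype so the rotary mul/add below
    # stays in fp16/bf16 under autocast instead of promoting q_embed/k_embed
    # back to fp32 (which is what trips FlashAttention 2).
    cos = cos.to(q.dtype).unsqueeze(unsqueeze_dim)
    sin = sin.to(q.dtype).unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```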

@ArthurZucker please comment

poedator added the bug label Feb 17, 2025
Ialzouby commented

I think a safer approach is to cast cos and sin to q.dtype inside apply_rotary_pos_emb(), instead of modifying LlamaModel.forward().

  • It keeps positional embeddings in fp32 where needed, avoiding potential precision issues in long-context models.
  • It prevents breaking compatibility with other attention implementations like sdpa.
  • It ensures FlashAttention 2 gets fp16/bf16 embeddings without affecting other layers globally.

If this sounds good, I can proceed with a PR using this approach!
