Setting: inference or training of Llama with Automatic Mixed Precision (AMP) autocast from fp32 to fp16 + FlashAttention 2 (FA2).
I observed that in newer versions of the Transformers library (>=4.43), training (and inference) fails with the error RuntimeError: FlashAttention only supports fp16 and bf16. This error does not occur with GPT2 or other parameter combinations. What is happening?
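For reference, a rough sketch of the kind of setup that hits this (the checkpoint name is a placeholder; any Llama-architecture model loaded in fp32 with FA2 should behave the same):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,                 # weights kept in fp32
    attn_implementation="flash_attention_2",   # from_pretrained only warns about fp32 + FA2
).to("cuda")

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(**inputs)
# transformers >= 4.43: RuntimeError: FlashAttention only supports fp16 and bf16
```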
Given:
FA2 supports only fp16/bf16 and fails when it encounters fp32.
Autocast does not cast all operations to fp16; elementwise ops such as add and mul simply follow ordinary type promotion.
The failure is caused by the fact that in transformers >= 4.43, positional embeddings in Llama are precomputed from the hidden_states (fp32) and are therefore also output in fp32. This is done in LlamaModel.forward(), before the layers' forward passes, using the following code: position_embeddings = self.rotary_emb(hidden_states, position_ids) (link)
These embeddings are then passed to the attention mechanism as a parameter.
In the attention class, we have: cos, sin = position_embeddings (link)
These fp32 values are then multiplied with and added to the autocasted (fp16) q and k. Autocast leaves these elementwise ops alone, so ordinary type promotion applies and q_embed ends up in fp32. This causes FA2 to fail. SDPA, in contrast, handles the mixed dtypes without issues.
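The promotion is easy to see in isolation (shapes are arbitrary, and rotate_half is omitted for brevity):

```python
import torch

with torch.autocast(device_type="cuda", dtype=torch.float16):
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)  # autocasted query
    cos = torch.randn(128, 64, device="cuda", dtype=torch.float32)      # fp32 rotary table
    sin = torch.randn(128, 64, device="cuda", dtype=torch.float32)

    # Elementwise mul/add are not on autocast's fp16 op list, so ordinary
    # type promotion applies and the result is upcast to fp32.
    q_embed = (q * cos) + (q * sin)
    print(q_embed.dtype)  # torch.float32 -> rejected by FlashAttention 2
```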
Why didn't this happen before?
In transformers <= 4.41 or so, Llama positional embeddings are recomputed in each layer (inefficiently) from value_states, which is fp16 under autocast, so they are also output in fp16. Hence, no error occurs.
Proposed solutions:
cast cos and sin to q.dtype in apply_rotary_pos_emb() (link)
cast position_embeddings to the target dtype right after creation here (a rough fragment is sketched after this list)
do nothing and let people use sdpa when autocasting from fp32. This is not that bad, since sdpa is quite fast by now.
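For illustration only, option 2 could look roughly like this inside LlamaModel.forward() (a hypothetical, non-verbatim fragment; detecting the target dtype via torch.is_autocast_enabled() / torch.get_autocast_gpu_dtype() is just one possible choice):

```python
# Hypothetical sketch of option 2, inside LlamaModel.forward() (not the actual source):
cos, sin = self.rotary_emb(hidden_states, position_ids)
if torch.is_autocast_enabled():
    # Cast the rotary tables to the autocast dtype so every layer receives
    # fp16/bf16 cos and sin instead of fp32.
    target_dtype = torch.get_autocast_gpu_dtype()
    cos, sin = cos.to(target_dtype), sin.to(target_dtype)
position_embeddings = (cos, sin)
```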
I think a safer approach is to cast cos and sin to q.dtype inside apply_rotary_pos_emb(), instead of modifying LlamaModel.forward(); a rough sketch of the change follows the list below.
This keeps positional embeddings in fp32 where needed, avoiding potential precision issues in long-context models.
It avoids breaking compatibility with other attention implementations such as SDPA.
It ensures FlashAttention 2 gets fp16/bf16 embeddings without affecting other layers globally.
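Concretely, a rough sketch of the change (paraphrasing the current apply_rotary_pos_emb; not an exact copy of the library source):

```python
import torch

def rotate_half(x):
    # Rotates half of the hidden dims of the input (same helper as in modeling_llama).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Proposed change: match q's (possibly autocasted) dtype so that q_embed and
    # k_embed are not silently promoted to fp32 when cos/sin arrive in fp32.
    cos = cos.to(dtype=q.dtype).unsqueeze(unsqueeze_dim)
    sin = sin.to(dtype=q.dtype).unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```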
If this sounds good, I can proceed with a PR using this approach!
@ArthurZucker please comment