Adding nnx Gemma2-2b (including overall fixes) to examples/gemma #4587

Open · wants to merge 4 commits into main
5 changes: 5 additions & 0 deletions examples/gemma/layers.py
@@ -51,7 +51,12 @@ def __init__(self, dim: int, *, rngs: nnx.Rngs):

def __call__(self, x: Array) -> Array:
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)

# Note: the GDM gemma reference uses jax.lax.rsqrt because it returns
# slightly different floats than jnp.reciprocal(jnp.sqrt(var + 1e-06)):
# normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06))
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))

# normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
# a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
# a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
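
For context, a minimal standalone sketch (shapes made up, not taken from the PR) comparing the two normalisation variants mentioned in the comment above; they agree to within a few float32 ulps:

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 16), dtype=jnp.float32)
var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)

a = x * jax.lax.rsqrt(var + 1e-06)             # GDM gemma variant
b = x * jnp.reciprocal(jnp.sqrt(var + 1e-06))  # variant used in this diff

print(jnp.max(jnp.abs(a - b)))  # tiny, but not exactly zero
```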
41 changes: 34 additions & 7 deletions examples/gemma/modules.py
@@ -135,6 +135,8 @@ def __call__(
query_proj = self.q_einsum(x)
key_proj, value_proj = self.kv_einsum(x)

# mdda: Where is qk_norm? Apparently self.use_qk_norm is not used in the DeepMind gemma reference.

query_proj = positional_embeddings.apply_rope(
query_proj,
segment_pos,
@@ -160,7 +162,20 @@ def __call__(
cache['k'], key_proj, slice_indices
)

logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
# mdda: Gemma2 needs a GQA branch, as in https://github.com/google-deepmind/gemma/blob/main/gemma/modules.py#L176
num_kv_heads = self.num_kv_heads
use_gqa = (num_kv_heads != self.num_heads and num_kv_heads > 1)
if use_gqa:
# Reshape matrices to enable einsums over groups.
b, t, kg, h = query_scaled.shape
query_scaled = query_scaled.reshape(
(b, t, num_kv_heads, int(kg / num_kv_heads), h)
)
logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj)
b, t, k, g, s = logits.shape
logits = logits.reshape((b, t, k * g, s))
else:
logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)

if self.attn_logits_soft_cap is not None:
logits = jnp.tanh(logits / self.attn_logits_soft_cap)
@@ -180,7 +195,19 @@ def __call__(
padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
self.sow_config.maybe_sow_attn_logits_topk(padded_logits, self)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)

# mdda: Gemma2 needs a GQA branch, as in https://github.com/google-deepmind/gemma/blob/main/gemma/modules.py#L208
if use_gqa:
# Reshape matrices to enable einsums over groups.
b, t, kg, h = probs.shape
probs = probs.reshape(
(b, t, num_kv_heads, int(kg / num_kv_heads), h)
)
encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj)
b, t, k, g, h = encoded.shape
encoded = encoded.reshape((b, t, k * g, h))
else:
encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
attn_output = self.attn_vec_einsum(encoded)

if cache is not None:
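
As a side note on the GQA branches above, here is a standalone shape sketch (illustrative sizes, not from any particular Gemma2 config) of how the query heads are folded into (kv_heads, group) before the grouped einsums and unfolded afterwards:

```python
import jax
import jax.numpy as jnp

B, T, S, H = 2, 4, 6, 16          # batch, query length, kv length, head dim
num_heads, num_kv_heads = 8, 4
G = num_heads // num_kv_heads     # query heads served by each kv head

rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (B, T, num_heads, H))
k = jax.random.normal(rng, (B, S, num_kv_heads, H))
v = jax.random.normal(rng, (B, S, num_kv_heads, H))

# Fold the query head axis into (kv_heads, group), attend, then unfold.
q_grouped = q.reshape(B, T, num_kv_heads, G, H)
logits = jnp.einsum('BTKGH,BSKH->BTKGS', q_grouped, k).reshape(B, T, num_heads, S)
probs = jax.nn.softmax(logits, axis=-1)

p_grouped = probs.reshape(B, T, num_kv_heads, G, S)
encoded = jnp.einsum('BTKGS,BSKH->BTKGH', p_grouped, v).reshape(B, T, num_heads, H)
print(logits.shape, encoded.shape)  # (2, 4, 8, 6) (2, 4, 8, 16)
```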
@@ -342,19 +369,19 @@ def __call__(
cache,
attn_mask,
)
attn_output += x
residual = attn_output
attn_output = self.pre_ffw_norm(attn_output)

if self.use_post_attn_norm:
attn_output = self.post_attn_norm(attn_output)
attn_output += x
self.sow_config.maybe_sow_rs_after_attention(attn_output, self)

residual = attn_output

attn_output = self.pre_ffw_norm(attn_output)
outputs = self.mlp(attn_output)
if self.use_post_ffw_norm:
outputs = self.post_ffw_norm(outputs)
outputs = residual + outputs
self.sow_config.maybe_sow_rs_after_ffw(outputs, self)

return cache, outputs

@property
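
To make the reordering in the block above easier to follow, here is a minimal pure-function sketch (names and the pre-attention norm are assumed, not copied from the PR) of the Gemma2 block: the optional post-attention norm is applied before the first residual add, and the optional post-FFW norm before the second:

```python
def gemma2_block(x, attn, mlp,
                 pre_attn_norm, post_attn_norm,
                 pre_ffw_norm, post_ffw_norm):
    h = attn(pre_attn_norm(x))
    if post_attn_norm is not None:
        h = post_attn_norm(h)   # norm the attention output first ...
    h = h + x                   # ... then add the residual
    residual = h
    h = mlp(pre_ffw_norm(h))
    if post_ffw_norm is not None:
        h = post_ffw_norm(h)
    return residual + h
```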
62 changes: 46 additions & 16 deletions examples/gemma/transformer.py
@@ -59,10 +59,10 @@ def from_path(cls, path: str) -> TransformerConfig:
try:
model = metadata['somewhere in orbax checkpoint']

if model in ('gemma-2-27-pt', 'gemma-2-27-it'):
return cls.gemma_27b()
elif model in ('gemma-2-9-pt', 'gemma-2-9-it'):
return cls.gemma_9b()
if model in ('gemma-2-9-pt', 'gemma-2-9-it'):
return cls.gemma2_9b()
elif model in ('gemma-2-27-pt', 'gemma-2-27-it'):
return cls.gemma2_27b()
except KeyError:
# V1 model that does not include model metadata.
# Fall back to previous method
@@ -128,30 +128,31 @@ def gemma_7b(cls):
)

@classmethod
def gemma_27b(cls):
num_layers = 46
def gemma2_2b(cls):
num_layers = 26
return cls(
num_layers=num_layers,
num_embed=256128,
embed_dim=4608,
hidden_dim=72728,
num_heads=32,
head_dim=128,
num_kv_heads=16,
embed_dim=2304,
hidden_dim=9216,
num_heads=8,
head_dim=256,
num_kv_heads=4,
final_logit_softcap=30.0,
use_post_attn_norm=True,
use_post_ffw_norm=True,
attention_types=(
modules.AttentionType.LOCAL_SLIDING,
modules.AttentionType.GLOBAL,
)
* int(num_layers / 2),
use_post_attn_norm=True,
use_post_ffw_norm=True,
#query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
attn_logits_soft_cap=50.0,
sliding_window_size=4096,
)
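
A quick standalone check (plain strings standing in for modules.AttentionType members) of the layer pattern built by the tuple multiplication above: the 26-layer 2b config alternates local-sliding and global attention:

```python
num_layers = 26
attention_types = ('LOCAL_SLIDING', 'GLOBAL') * (num_layers // 2)
assert len(attention_types) == num_layers
assert attention_types[0] == 'LOCAL_SLIDING' and attention_types[1] == 'GLOBAL'
```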

@classmethod
def gemma_9b(cls):
def gemma2_9b(cls):
num_layers = 42
return cls(
num_layers=num_layers,
@@ -169,10 +170,35 @@ def gemma_9b(cls):
* int(num_layers / 2),
use_post_attn_norm=True,
use_post_ffw_norm=True,
#query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
attn_logits_soft_cap=50.0,
sliding_window_size=4096,
)

@classmethod
def gemma2_27b(cls):
print("Note : BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS is not implemented as a QueryPreAttentionNormalisation type in Attention")
num_layers = 46
return cls(
num_layers=num_layers,
num_embed=256128,
embed_dim=4608,
hidden_dim=72728,
num_heads=32,
head_dim=128,
num_kv_heads=16,
final_logit_softcap=30.0,
attention_types=(
modules.AttentionType.LOCAL_SLIDING,
modules.AttentionType.GLOBAL,
)
* int(num_layers / 2),
use_post_attn_norm=True,
use_post_ffw_norm=True,
#query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS, # in 'real gemma'
attn_logits_soft_cap=50.0,
sliding_window_size=4096,
)
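
For the note printed above, a small illustrative calculation (values taken from this 27b config; the reading of the enum name is my assumption) of the two query scalings, the one Attention applies today versus the one the upstream 27b checkpoint expects:

```python
head_dim, embed_dim, num_heads = 128, 4608, 32
by_sqrt_head_dim = head_dim ** -0.5                                  # ~0.0884, BY_ONE_OVER_SQRT_HEAD_DIM
by_sqrt_embed_dim_div_num_heads = (embed_dim // num_heads) ** -0.5   # ~0.0833, BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS
print(by_sqrt_head_dim, by_sqrt_embed_dim_div_num_heads)
```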


def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
new_key = []
@@ -188,6 +214,8 @@ def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
elif k == 'linear':
new_key.append('down_proj')
new_key.append('kernel')
elif k == 'post_attention_norm':  # the gemma2-2b checkpoint has this misnamed key
new_key.append('post_attn_norm')
else:
new_key.append(k)
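
A minimal standalone sketch (the key tuple is simplified and assumed for illustration) of the rename this branch adds for the mis-named Gemma2-2b checkpoint key:

```python
key = ('layer_0', 'post_attention_norm', 'scale')
new_key = tuple('post_attn_norm' if k == 'post_attention_norm' else k for k in key)
print(new_key)  # ('layer_0', 'post_attn_norm', 'scale')
```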

@@ -285,6 +313,7 @@ def __call__(
new_cache = None if cache is None else {}
x = self.embedder.encode(last_tokens)
self.sow_config.maybe_sow_embeddings(x, self)

for i, layer in enumerate(self.layers):
layer_name = f'layer_{i}'
layer_cache = cache[layer_name] if cache else None
@@ -294,6 +323,7 @@ def __call__(
layer_cache,
attention_mask,
)

if cache is not None:
new_cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch

@@ -303,7 +333,7 @@ def __call__(
if self.final_logits_softcap is not None:
logits /= self.final_logits_softcap
logits = jnp.tanh(logits) * self.final_logits_softcap

return logits, new_cache # pytype: disable=bad-return-type
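
As an aside, the final-logit soft-capping applied just above can be checked in isolation (the cap value 30.0 matches the Gemma2 configs in this file; the input logits are made up):

```python
import jax.numpy as jnp

cap = 30.0
logits = jnp.array([-200.0, -5.0, 0.0, 5.0, 200.0])
capped = jnp.tanh(logits / cap) * cap  # squashed smoothly into (-cap, cap)
print(capped)                          # extremes approach +/-30, small values barely change
```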

@property