diff --git a/examples/gemma/layers.py b/examples/gemma/layers.py
index 5dafbd15d..d3e258d35 100644
--- a/examples/gemma/layers.py
+++ b/examples/gemma/layers.py
@@ -51,7 +51,12 @@ def __init__(self, dim: int, *, rngs: nnx.Rngs):

   def __call__(self, x: Array) -> Array:
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
+
+    # Note that GDM gemma has:
+    # """ jax.lax.rsqrt is used because it returns different floats than jnp.reciprocal(jnp.sqrt(var + 1e-06)) """
+    # normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06))
     normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
+
     # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
     # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
     # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index 5eb6fa578..d72f73f07 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -135,6 +135,8 @@ def __call__(
     query_proj = self.q_einsum(x)
     key_proj, value_proj = self.kv_einsum(x)

+    # mdda: Where is qk_norm? Apparently self.use_qk_norm is not used in deep-mind gemma.
+
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
@@ -160,7 +162,20 @@ def __call__(
           cache['k'], key_proj, slice_indices
       )

-    logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)
+    # mdda: Gemma2 needs a GQA branch like https://github.com/google-deepmind/gemma/blob/main/gemma/modules.py#L176
+    num_kv_heads = self.num_kv_heads
+    use_gqa = (num_kv_heads != self.num_heads and num_kv_heads > 1)
+    if use_gqa:
+      # Reshape matrices to enable einsums over groups.
+      b, t, kg, h = query_scaled.shape
+      query_scaled = query_scaled.reshape(
+          (b, t, num_kv_heads, int(kg / num_kv_heads), h)
+      )
+      logits = jnp.einsum('BTKGH,BSKH->BTKGS', query_scaled, key_proj)
+      b, t, k, g, s = logits.shape
+      logits = logits.reshape((b, t, k * g, s))
+    else:
+      logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj)

     if self.attn_logits_soft_cap is not None:
       logits = jnp.tanh(logits / self.attn_logits_soft_cap)
@@ -180,7 +195,19 @@ def __call__(
     padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
     self.sow_config.maybe_sow_attn_logits_topk(padded_logits, self)
     probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
-    encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
+
+    # mdda: Gemma2 needs a GQA branch like https://github.com/google-deepmind/gemma/blob/main/gemma/modules.py#L208
+    if use_gqa:
+      # Reshape matrices to enable einsums over groups.
+      b, t, kg, h = probs.shape
+      probs = probs.reshape(
+          (b, t, num_kv_heads, int(kg / num_kv_heads), h)
+      )
+      encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value_proj)
+      b, t, k, g, h = encoded.shape
+      encoded = encoded.reshape((b, t, k * g, h))
+    else:
+      encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj)
     attn_output = self.attn_vec_einsum(encoded)

     if cache is not None:
@@ -342,19 +369,19 @@ def __call__(
         cache,
         attn_mask,
     )
-    attn_output += x
-    residual = attn_output
-    attn_output = self.pre_ffw_norm(attn_output)
-
     if self.use_post_attn_norm:
       attn_output = self.post_attn_norm(attn_output)
+    attn_output += x
     self.sow_config.maybe_sow_rs_after_attention(attn_output, self)
-
+    residual = attn_output
+
+    attn_output = self.pre_ffw_norm(attn_output)
     outputs = self.mlp(attn_output)
     if self.use_post_ffw_norm:
       outputs = self.post_ffw_norm(outputs)
     outputs = residual + outputs
     self.sow_config.maybe_sow_rs_after_ffw(outputs, self)
+
     return cache, outputs

   @property
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index cdf607c1a..64c6a8692 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -59,10 +59,10 @@ def from_path(cls, path: str) -> TransformerConfig:

     try:
       model = metadata['somewhere in orbax checkpoint']
-      if model in ('gemma-2-27-pt', 'gemma-2-27-it'):
-        return cls.gemma_27b()
-      elif model in ('gemma-2-9-pt', 'gemma-2-9-it'):
-        return cls.gemma_9b()
+      if model in ('gemma-2-9-pt', 'gemma-2-9-it'):
+        return cls.gemma2_9b()
+      elif model in ('gemma-2-27-pt', 'gemma-2-27-it'):
+        return cls.gemma2_27b()
     except KeyError:
       # V1 model that does not include model metadata.
       # Fall back to previous method
@@ -128,30 +128,31 @@ def gemma_7b(cls):
     )

   @classmethod
-  def gemma_27b(cls):
-    num_layers = 46
+  def gemma2_2b(cls):
+    num_layers = 26
     return cls(
         num_layers=num_layers,
         num_embed=256128,
-        embed_dim=4608,
-        hidden_dim=72728,
-        num_heads=32,
-        head_dim=128,
-        num_kv_heads=16,
+        embed_dim=2304,
+        hidden_dim=9216,
+        num_heads=8,
+        head_dim=256,
+        num_kv_heads=4,
         final_logit_softcap=30.0,
-        use_post_attn_norm=True,
-        use_post_ffw_norm=True,
         attention_types=(
             modules.AttentionType.LOCAL_SLIDING,
             modules.AttentionType.GLOBAL,
         )
         * int(num_layers / 2),
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        # query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
         attn_logits_soft_cap=50.0,
         sliding_window_size=4096,
     )

   @classmethod
-  def gemma_9b(cls):
+  def gemma2_9b(cls):
     num_layers = 42
     return cls(
         num_layers=num_layers,
@@ -169,10 +170,35 @@ def gemma_9b(cls):
         * int(num_layers / 2),
         use_post_attn_norm=True,
         use_post_ffw_norm=True,
+        # query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
+        attn_logits_soft_cap=50.0,
+        sliding_window_size=4096,
+    )
+
+  @classmethod
+  def gemma2_27b(cls):
+    print("Note: BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS is not implemented as a QueryPreAttentionNormalisation type in Attention")
+    num_layers = 46
+    return cls(
+        num_layers=num_layers,
+        num_embed=256128,
+        embed_dim=4608,
+        hidden_dim=72728,
+        num_heads=32,
+        head_dim=128,
+        num_kv_heads=16,
+        final_logit_softcap=30.0,
+        attention_types=(
+            modules.AttentionType.LOCAL_SLIDING,
+            modules.AttentionType.GLOBAL,
+        )
+        * int(num_layers / 2),
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        # query_pre_attn_norm=transformer_lib.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS,  # in 'real gemma'
         attn_logits_soft_cap=50.0,
         sliding_window_size=4096,
     )
-

 def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
   new_key = []
@@ -188,6 +214,8 @@ def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
     elif k == 'linear':
       new_key.append('down_proj')
       new_key.append('kernel')
+    elif k == 'post_attention_norm':  # gemma2-2b checkpoints use this name for post_attn_norm
+      new_key.append('post_attn_norm')
     else:
       new_key.append(k)

@@ -285,6 +313,7 @@ def __call__(
     new_cache = None if cache is None else {}
     x = self.embedder.encode(last_tokens)
     self.sow_config.maybe_sow_embeddings(x, self)
+
     for i, layer in enumerate(self.layers):
       layer_name = f'layer_{i}'
       layer_cache = cache[layer_name] if cache else None
@@ -294,6 +323,7 @@ def __call__(
           layer_cache,
           attention_mask,
       )
+
       if cache is not None:
         new_cache[layer_name] = layer_cache  # pytype: disable=container-type-mismatch

@@ -303,7 +333,7 @@ def __call__(
     if self.final_logits_softcap is not None:
       logits /= self.final_logits_softcap
       logits = jnp.tanh(logits) * self.final_logits_softcap
-
+
     return logits, new_cache  # pytype: disable=bad-return-type

   @property
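
Notes on the patch above (standalone sketches for readers; none of this is part of the diff itself).

layers.py: the added comment quotes the upstream observation that jax.lax.rsqrt returns different floats than jnp.reciprocal(jnp.sqrt(var + 1e-06)). A quick check, my own snippet rather than code from either repo, to see whether the two paths actually differ on a given backend (the difference, if any, is in the last bits and matters mainly when comparing activations bit-for-bit against the reference implementation):

import jax
import jax.numpy as jnp

var = jnp.float32(0.123)
a = jax.lax.rsqrt(var + 1e-06)
b = jnp.reciprocal(jnp.sqrt(var + 1e-06))
# Prints the two values and their absolute difference; may be exactly 0.0
# on some backends and a 1-ulp difference on others.
print(a, b, jnp.abs(a - b))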
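modules.py: the new use_gqa branches mirror the grouped-query attention einsums in google-deepmind/gemma. Queries with N = K * G heads are reshaped to (B, T, K, G, H) so that each group of G query heads shares one of the K key/value heads; the patch flattens the logits back to (B, T, N, S) between the two einsums so the soft-cap and masking code stays unchanged. Below is a minimal self-contained sketch of the same einsum pattern; the function name and toy shapes are illustrative only, and soft-capping/masking are omitted, with the grouped layout kept throughout for brevity:

import jax
import jax.numpy as jnp


def grouped_query_attention(query, key, value, num_kv_heads):
  """GQA sketch: query has N = K * G heads, key/value have K heads."""
  b, t, n, h = query.shape
  g = n // num_kv_heads  # query heads per KV head
  # Group the query heads so each group attends with a single KV head.
  query = query.reshape((b, t, num_kv_heads, g, h))
  logits = jnp.einsum('BTKGH,BSKH->BTKGS', query, key)
  probs = jax.nn.softmax(logits, axis=-1)
  encoded = jnp.einsum('BTKGS,BSKH->BTKGH', probs, value)
  # Flatten the groups back to N = K * G heads.
  return encoded.reshape((b, t, n, h))


# Toy shapes: 16 query heads sharing 8 KV heads.
q = jnp.ones((1, 4, 16, 128))
k = jnp.ones((1, 4, 8, 128))
v = jnp.ones((1, 4, 8, 128))
print(grouped_query_attention(q, k, v, num_kv_heads=8).shape)  # (1, 4, 16, 128)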
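modules.py, Block.__call__: the reordering applies post_attn_norm to the attention branch output before the residual input x is added back, and only then starts the feed-forward branch from the fresh residual. A simplified view of the new data flow, assuming both post-norms are enabled and ignoring the cache, the pre-attention norm and the sow hooks; attn here stands in for the already-computed attention branch of the real module:

import jax.numpy as jnp

def gemma2_block(x, attn, post_attn_norm, pre_ffw_norm, mlp, post_ffw_norm):
  # Attention branch: normalise the branch output, then add the residual.
  attn_output = post_attn_norm(attn(x))
  attn_output = attn_output + x
  residual = attn_output
  # Feed-forward branch: pre-norm, MLP, post-norm, then the second residual add.
  outputs = post_ffw_norm(mlp(pre_ffw_norm(attn_output)))
  return residual + outputs

# Identity stand-ins just to show the data flow.
identity = lambda x: x
x = jnp.ones((1, 4, 8))
print(gemma2_block(x, identity, identity, identity, identity, identity).shape)  # (1, 4, 8)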
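transformer.py, _map_linen_var_names: the added elif maps the 'post_attention_norm' name used by the gemma2-2b checkpoint onto the post_attn_norm attribute of the nnx Block. A stripped-down illustration of just that rename (the real function also rewrites the einsum and MLP entries, and the key path shown is only an example):

def map_var_name(key: tuple[str, ...]) -> tuple[str, ...]:
  # Only the rename added in the diff; everything else passes through unchanged.
  renames = {'post_attention_norm': 'post_attn_norm'}
  return tuple(renames.get(k, k) for k in key)

print(map_var_name(('layer_0', 'post_attention_norm', 'scale')))
# -> ('layer_0', 'post_attn_norm', 'scale')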