Add vision for Gemma3 #2170

Merged: 57 commits, Apr 3, 2025

Commits (all authored by abheesht17)
e87f9be  Edit causal LM preprocessor to handle images (Mar 26, 2025)
c2a0571  Fix interleaving layer (Mar 27, 2025)
231f4f3  Fix interleaving UTs + verify generate output (Mar 27, 2025)
2aa7c2f  Workaround for scatter_update bug on Torch (Mar 27, 2025)
6661a51  Allow lists and ragged tensors as images (Mar 27, 2025)
f993170  Allow unbatched inputs for preprocessor (Mar 27, 2025)
b473678  Make text only case work with tf.data (Mar 27, 2025)
9916a00  Preprocessor working with tf.data now (Mar 27, 2025)
c77dea5  Add ViT layer LoRA layers (Mar 27, 2025)
438a441  Handle unbatched inputs in generate() (Mar 27, 2025)
3c47d1d  Fix for jax Array, torch tensor automatic typecasting (Mar 27, 2025)
e74c0ef  Move tensor to CPU for Torch (Mar 27, 2025)
f8f2beb  Same as prev commit (Mar 27, 2025)
14d025a  Always do img preprocessing in float32 (Mar 27, 2025)
819edf5  Aaargh (Mar 27, 2025)
d4291da  Torch fixes (Mar 27, 2025)
78054f5  Allow None dtypes (Mar 27, 2025)
1b0e202  Workaround for tokenizer issue (Mar 27, 2025)
2aeddc4  Small fix (Mar 27, 2025)
bcac761  Small fix (Mar 27, 2025)
f18ab4b  Typo (Mar 28, 2025)
cb1d58b  Override normalize_generate_inputs to handle unbatched images (Mar 28, 2025)
263160f  Change upranking logic for generation (Mar 28, 2025)
89243ad  Some doc-string brushing up (Mar 29, 2025)
1718c91  Add a cache update mask for proper batched input generation (Mar 29, 2025)
047bcd9  Force ViT to use dtype bfloat16 (Mar 30, 2025)
30c3ad6  Small edit (Mar 30, 2025)
63d597f  Brush up doc-strings, address a few comments (Mar 30, 2025)
c6fc29e  Fix tests (Mar 30, 2025)
164f1db  Make existing tests pass (Mar 31, 2025)
f527095  Add backbone, causal LM preprocessor tests (Mar 31, 2025)
e4c61e3  Remove print statements (Mar 31, 2025)
04b2fe1  Add causal LM tests (Mar 31, 2025)
b2b7038  Reduce duplicate code in preprocessor (Mar 31, 2025)
61a382d  Add getter/setter for max_images_per_prompt (Mar 31, 2025)
d2b49a7  Fix normalize generate inputs (Mar 31, 2025)
83da842  Temp change (Mar 31, 2025)
f9a0b49  Address nits (Apr 1, 2025)
193592a  Dtypes (Apr 1, 2025)
f8f13d4  Export Gemma3VisionEncoder (Apr 1, 2025)
4092199  Small doc-string edit (Apr 1, 2025)
043b04b  Fix rename (Apr 1, 2025)
25d78e2  Add local, global scaling factors as args to fix text 1B (Apr 1, 2025)
b6a23e9  Move sliding window attn before FA block for Gemma3 (Apr 1, 2025)
709c11c  Update 1B ckpts (Apr 1, 2025)
761ddaf  Final 1B ckpts (Apr 1, 2025)
333594e  Modify text presets (Apr 2, 2025)
c10729e  Copy over Divya's changes for FA, pull in changes later (Apr 2, 2025)
1b5e5d6  ViT works fine on bfloat16 [T4] (Apr 2, 2025)
1c687fe  Eh, stick with float32 for ViT to maintain parity (Apr 2, 2025)
24674e2  Cast indices to int32 (Apr 2, 2025)
4c76d0d  Try forcing interleaving layer to float32 (Apr 2, 2025)
aeef3a7  Pull in master changes (Apr 3, 2025)
ff8af1f  Revert interleave emb layer dtype back to dtype, indices are correct … (Apr 3, 2025)
b03d9ee  Add vision presets (Apr 3, 2025)
43712d7  Merge branch 'master' of https://github.com/abheesht17/keras-nlp into… (Apr 3, 2025)
2575b57  Final changes to presets file (Apr 3, 2025)
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -183,6 +183,9 @@
Gemma3CausalLMPreprocessor,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
Gemma3VisionEncoder,
)
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
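The added lines above re-export the new vision encoder from the top-level API module. A minimal sketch of what this enables, assuming the entries in this generated file surface under the public `keras_hub.models` namespace as the neighboring symbols do:

```python
# Assumes keras_hub is installed with this change applied. The constructor
# arguments of Gemma3VisionEncoder are not shown in this diff, so only the
# import path is illustrated here.
import keras_hub

print(keras_hub.models.Gemma3VisionEncoder)
print(keras_hub.models.Gemma3Tokenizer)
print(keras_hub.models.Gemma3CausalLMPreprocessor)
```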
95 changes: 74 additions & 21 deletions keras_hub/src/models/gemma3/gemma3_attention.py
@@ -8,19 +8,28 @@
from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu


class CachedGemma3Attention(keras.layers.Layer):
"""A cached grouped query attention layer for Gemma3.

This is different from Gemma and Gemma2 in several ways:
This is the same as the attention layer used for Gemma and Gemma2. It
exposes a few additional args:

- `use_query_key_norm`: Applies RMS Norm on query, key.
- `rope_wavelength`: RoPE wavelength differs from local to global attention
layers.
- `rope_scaling_factor`: RoPE scaling factor differs from local to global
attention layers.
`use_query_key_norm`: bool. If True, apply RMS normalization on query
and key. For Gemma3, this is True.
`rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
uses 10K for local attention layers and 1M for global attention layers.
`gate_dim_reduction`: int. In the gating layers, the output dimension is
`intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
value is 2. For Gemma3, it is 1.

Moreover, the call() method takes in a `cache_update_mask` so as to make
sure that the key-value cache is updated only for the non-prompt tokens
during generation.
"""

def __init__(
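To make the `use_query_key_norm` description above concrete, here is a conceptual sketch of RMS-normalizing the query and key projections over the head dimension before attention scores are computed. It is an illustration of the idea only, not the layer's exact implementation (the real `RMSNormalization` layer owns a learned scale weight):

```python
# Conceptual sketch of query/key RMS normalization, assuming illustrative
# shapes; not the CachedGemma3Attention internals.
import keras
from keras import ops

def rms_norm(x, scale, eps=1e-6):
    # Normalize over the last (head) dimension, then apply a scale.
    var = ops.mean(ops.square(x), axis=-1, keepdims=True)
    return x * ops.rsqrt(var + eps) * scale

head_dim = 256
scale = ops.ones((head_dim,))
query = keras.random.normal((2, 8, 4, head_dim))  # (batch, seq, query_heads, head_dim)
key = keras.random.normal((2, 8, 1, head_dim))    # grouped-query: fewer KV heads

query, key = rms_norm(query, scale), rms_norm(key, scale)
```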
@@ -139,17 +148,22 @@ def _apply_rope(self, x, start_index):
x = self.rope_layer(x, start_index=start_index)
return x

def _can_use_flash_attention(self):
def _use_fused_attention_op(self):
if not fused_attention_op_available():
return False
if self.dropout > 0.0:
return False
if self.logit_soft_cap is None:
return True
sig = inspect.signature(ops.dot_product_attention)
# We can currently only run soft capped attention for keras >= 3.10
# and only on TPU.
return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
if running_on_gpu():
# GPU never supports softcap in the fused op.
if self.logit_soft_cap is not None:
return False
return gpu_supports_fused_attention_op()
elif running_on_tpu():
# TPU supports softcap on keras >= 3.10.
sig = inspect.signature(ops.dot_product_attention)
return "attn_logits_soft_cap" in sig.parameters
else:
return False

def _compute_attention(
self,
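The dispatch above chooses between the fused `ops.dot_product_attention` path and the plain einsum path. A standalone check mirroring only the Keras-version part of that decision, i.e. whether the installed Keras exposes `attn_logits_soft_cap` (available from Keras 3.10):

```python
# Checks whether ops.dot_product_attention accepts attn_logits_soft_cap,
# which the TPU branch above requires before using the fused op together
# with Gemma-style logit soft capping.
import inspect

from keras import ops

sig = inspect.signature(ops.dot_product_attention)
print("attn_logits_soft_cap" in sig.parameters)
```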
@@ -166,7 +180,14 @@ def _compute_attention(
query_normalization = 1 / np.sqrt(
self.hidden_dim // self.num_query_heads
)
if self._can_use_flash_attention():

if self.use_sliding_window_attention and attention_mask is not None:
attention_mask = self._mask_sliding_window(
attention_mask,
cache_update_index=cache_update_index,
)

if self._use_fused_attention_op():
if attention_mask is not None:
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")
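This hunk moves the sliding-window masking ahead of the fused-attention branch so both code paths see the banded mask. As a rough sketch of what such a mask looks like (not the layer's `_mask_sliding_window`, which operates on an existing mask plus a cache index):

```python
# Toy sliding-window causal mask: each query position attends only to the
# previous `window` key positions, including itself.
from keras import ops

def sliding_window_causal_mask(seq_len, window):
    i = ops.arange(seq_len)[:, None]  # query positions
    j = ops.arange(seq_len)[None, :]  # key positions
    causal = j <= i
    in_window = (i - j) < window
    return ops.logical_and(causal, in_window)  # (seq_len, seq_len) bool

print(sliding_window_causal_mask(6, 3))
```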
@@ -205,13 +226,8 @@
ops.tanh(attention_logits), self.logit_soft_cap
)

if self.use_sliding_window_attention:
attention_mask = self._mask_sliding_window(
attention_mask,
cache_update_index=cache_update_index,
)

attention_mask = attention_mask[:, None, None, :, :]
if attention_mask is not None:
attention_mask = attention_mask[:, None, None, :, :]
orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits, mask=attention_mask)
attention_softmax = ops.cast(attention_softmax, orig_dtype)
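For reference, the soft-cap line above follows the Gemma-family recipe of dividing the logits by the cap, applying `tanh`, and scaling back, which bounds attention logits to `(-cap, cap)`; the division step presumably sits just above the lines visible in this hunk. A small self-contained illustration:

```python
# Minimal illustration of attention-logit soft capping: values are squashed
# smoothly into (-cap, cap) rather than hard-clipped.
from keras import ops

def soft_cap(logits, cap):
    return ops.multiply(ops.tanh(ops.divide(logits, cap)), cap)

print(soft_cap(ops.convert_to_tensor([0.0, 25.0, 250.0]), 50.0))
```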
@@ -256,6 +272,7 @@ def call(
attention_mask=None,
cache=None,
cache_update_index=0,
cache_update_mask=None,
training=False,
):
query = self.query_dense(x)
@@ -275,7 +292,43 @@

key_update = self._apply_rope(key_update, cache_update_index)
value_update = self.value_dense(x)

# Update cache. Note that the cache is updated only if the
# corresponding `cache_update_mask` value is True. This is to
# ensure that we don't update the cache at indices corresponding to
# the prompt. For Gemma3, in particular, this is useful because
# image tokens have bidirectional attention. During generation, if we
# have uneven inputs, we might end up having causal attention between
# image tokens, which is incorrect. To
# avoid this, bidirectional attention is taken care of during
# the prefill step, and during generation, the cache is not updated
# for the prompt. The shape of `cache_update_mask` is
# `(bsz, seq_len)`, where `seq_len` is 1 when we are generating
# token-by-token.
start = [0, cache_update_index, 0, 0]
if cache_update_mask is not None:
cache_update_mask = ops.expand_dims(
ops.expand_dims(cache_update_mask, axis=-1),
axis=-1,
)
key_original = ops.slice(
key_cache, start, ops.shape(key_update)
)
value_original = ops.slice(
value_cache, start, ops.shape(value_update)
)

key_update = ops.where(
cache_update_mask,
key_update,
key_original,
)
value_update = ops.where(
cache_update_mask,
value_update,
value_original,
)

key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
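The comment in this hunk explains why the key/value cache must not be overwritten at prompt positions during decoding. A minimal sketch of the same masked-update pattern on a toy cache, assuming an illustrative `(batch, seq_len, num_heads, head_dim)` cache already filled at prefill time:

```python
from keras import ops

# Toy key cache after prefill: every position holds 1.0.
cache = ops.ones((1, 4, 1, 1))  # (batch, seq_len, num_heads, head_dim)

# Candidate update for position 2 while still consuming prompt tokens.
update = ops.full((1, 1, 1, 1), 7.0)
mask = ops.convert_to_tensor([[False]])                # (batch, 1): prompt token
mask = ops.expand_dims(ops.expand_dims(mask, -1), -1)  # (batch, 1, 1, 1)

start = (0, 2, 0, 0)
original = ops.slice(cache, start, ops.shape(update))
masked_update = ops.where(mask, update, original)      # keep the prefill value
cache = ops.slice_update(cache, start, masked_update)

print(ops.squeeze(cache))  # [1. 1. 1. 1.]; position 2 kept 1.0, not 7.0
```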