Add vision for Gemma3 #2170
base: master
Conversation
Okay! @mattdangerw / @divyashreepathihalli - this is ready for review, mostly. I'm still filling in the docstrings and the unit tests, but you can review the rest. Also, it's probably a good idea to refer to #2152, because a lot of the vision components were added in that previous PR.
LGTM, a few nit comments for cleanup.
```python
# Truth be told, this is redundant, and we can infer this from
# `vision_indices_input` directly.
text_mask_input = keras.Input(
    shape=(None,), dtype="int32", name="text_mask"
)
```
maybe remove this comment?
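A hedged sketch of the redundancy the comment refers to: if vision indices give the positions of vision tokens in each sequence, a boolean text mask can be derived from them rather than passed as a separate input (names and shapes here are illustrative, not the PR's exact API):

```python
import numpy as np

# Illustrative shapes: batch of 1, sequence length 8, vision tokens at
# positions 2 and 3. `vision_indices` is a stand-in for the PR's
# `vision_indices_input`.
seq_length = 8
vision_indices = np.array([[2, 3]])

# A position is text iff it is not listed in `vision_indices`.
text_mask = np.ones((vision_indices.shape[0], seq_length), dtype=bool)
batch_idx = np.arange(vision_indices.shape[0])[:, None]
text_mask[batch_idx, vision_indices] = False
```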
```python
# == Branch: vision model, with non-`None` value for `images` ==

# Check: token IDs should not have less than 1, or more than
# `max_images_per_prompt` start of image tokens.
```
remove commented code
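The check being discussed can be sketched like this (the token id and shapes are made up for illustration; the real preprocessor works with its own tokenizer's ids):

```python
import numpy as np

START_OF_IMAGE_TOKEN_ID = 255999  # hypothetical id, for illustration only
max_images_per_prompt = 2

token_ids = np.array([
    [1, START_OF_IMAGE_TOKEN_ID, 7, 8, 2],
    [1, START_OF_IMAGE_TOKEN_ID, 7, START_OF_IMAGE_TOKEN_ID, 2],
])

# Count start-of-image tokens per prompt; each prompt must contain
# between 1 and `max_images_per_prompt` of them.
counts = np.sum(token_ids == START_OF_IMAGE_TOKEN_ID, axis=-1)
valid = (counts >= 1) & (counts <= max_images_per_prompt)
```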
```python
# == Branch: vision model, with non-`None` value for `images` ==

# Check: token IDs should not have less than 0, or more than
# `max_images_per_prompt` start of image tokens.
```
remove commented code
```python
START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"
```
Need to remove this
They are no longer used, right?
Yeah, no longer used
Thanks! Haven't done a full read-through of the preprocessor yet, but left some comments.
```python
# Add these for `Gemma3VITAttention`.
if not self.text_only_model:
    target_names += ["query_proj", "value_proj"]
```
Does it hurt to just always leave these as part of the targets? They just won't match, right?
Hmmm, yeah. But let's keep the `if` condition for clarity.
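The point about non-matching targets can be illustrated with a toy name filter (this is a sketch, not KerasHub's actual LoRA plumbing; all names here are illustrative):

```python
# Hypothetical layer names for a text-only model: the vision-tower
# projections never appear, so extra target names are simply inert.
layer_names = ["query_dense", "value_dense", "ffw_linear"]
target_names = ["query_dense", "value_dense", "query_proj", "value_proj"]

# Only the names present in the model actually match; unmatched
# targets are ignored rather than causing an error.
matched = [name for name in layer_names if name in target_names]
```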
```python
text_mask: Boolean tensor of shape `(batch_size, seq_length)`.
image_embeddings: tensor. Image embeddings as returned by the
    vision encoder (`Gemma3ViT`, usually). Shape:
    `(batch_size * num_images_per_prompt, num_vision_tokens_per_image,`
```
Keep this indented so the arg list reads right.
```python
def __init__(self, **kwargs):
    # Always do image preprocessing in float32.
    kwargs.pop("dtype", None)
    dtype = "float32"
```
Why is this, btw? Won't the images get converted to the compute dtype later?
I want to do standardisation and normalisation in float32, because these ops are sensitive to precision. What do you think - worth keeping?
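One concrete motivation for the float32 choice (a self-contained illustration, not code from the PR): float16 tops out at 65504, so even accumulating pixel values for per-image statistics can overflow in half precision:

```python
import numpy as np

# A 224x224 image filled with a typical pixel value.
img = np.full((224, 224), 200.0, dtype=np.float16)

# Accumulating in float16 overflows (float16 max is 65504)...
half_sum = img.sum(dtype=np.float16)  # inf

# ...while float32 has plenty of headroom.
full_sum = img.astype(np.float32).sum()  # 224 * 224 * 200 = 10035200.0
```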
```python
START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"


@keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
```
Not for this PR, but it looks like our prompts/responses setup will not work for multi-turn conversations. We should consider how we want that to work in the future.
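For reference, a hedged sketch of what multi-turn input could look like using Gemma's turn markers (exactly how prompts/responses would map onto this template is the open question above, not something this PR implements):

```python
# Gemma-style turn markers; the mapping from a flat prompts/responses
# pair onto these turns is the unresolved design question.
turns = [
    ("user", "What is in this image? <start_of_image>"),
    ("model", "A cat sitting on a mat."),
    ("user", "What colour is it?"),
]

prompt = "".join(
    f"<start_of_turn>{role}\n{text}<end_of_turn>\n" for role, text in turns
) + "<start_of_turn>model\n"
```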
```python
    else original_image_shape[-2]
)

if keras.config.backend() == "torch" and not isinstance(
```
This seems unsafe, given that in some modes of operation `images` is user input, right? We could also get a np array with a `.cpu()` function, for example. Maybe do `images = ops.convert_to_numpy(images)`, which should handle CPU conversion and the other cases.
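The suggestion amounts to funnelling all inputs through one conversion point instead of special-casing the torch backend. A minimal duck-typed sketch of such a helper (KerasHub would use `keras.ops.convert_to_numpy`; this standalone version just mirrors the idea):

```python
import numpy as np

def to_numpy(x):
    """Hypothetical helper mirroring `ops.convert_to_numpy`."""
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "detach"):
        # torch.Tensor: move off the accelerator before converting.
        return x.detach().cpu().numpy()
    # Lists, tuples, and other array-likes.
    return np.asarray(x)

images = to_numpy([[0.0, 1.0], [2.0, 3.0]])
```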
```python
if responses is not None:
    responses = tf.expand_dims(responses, axis=0)

# There are 8 cases, based on values of
```
Consider if there's some common utilities to refactor and share between call/generate. I see a lot of duplicated code.
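One small example of the kind of utility that could be shared between `call()` and `generate()` (a hypothetical sketch; the actual refactor is up to the PR author):

```python
import numpy as np

def ensure_batched(x):
    """Hypothetical shared helper: promote an unbatched input to a
    batch of size 1, mirroring the `expand_dims(..., axis=0)` calls
    duplicated across call/generate."""
    x = np.asarray(x)
    if x.ndim < 2:
        x = np.expand_dims(x, axis=0)
    return x
```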
Bunch of notebooks to demonstrate how the different components work:

- `CausalLMPreprocessor`: https://colab.research.google.com/drive/1fAb3Rvrw2zRZd5gfVCmB5WW2qdWl_eMN?resourcekey=0-uJD6GoFVgREFVSeMVaXAGA&usp=sharing
- `.generate()`: https://colab.research.google.com/drive/1l2atV5VNt9HYKk-BZ39YcQBcsNG4D94V?resourcekey=0-qX4Xka61fikIygXpFZ-eGA&usp=sharing
- `.fit()` on a randomly initialised model (because even the 4B one cannot fit on an A100): https://colab.research.google.com/drive/11Yi9oAtBs9VvrwJPmtY9bsr8ZajwRhl6?resourcekey=0-wKYPYAm9uw51k1T39DfOGg&usp=sharing
- JAX `generate()`: https://colab.research.google.com/drive/1AR7G3UFawqO9VsNDTtLUrJP-dRTvbkh1?resourcekey=0-4ofXHp_ko1LDzFIG9pYkdQ&usp=sharing
- JAX `.fit()`: https://colab.research.google.com/drive/15IF3wF55EIbuenRB99R_JgbYI8Z_eA9U?resourcekey=0-w0sOujM6cxO1qsvkTXYiDA&usp=sharing