
Add vision for Gemma3 #2170

Open
wants to merge 35 commits into base: master
Conversation

@github-actions github-actions bot added the Gemma Gemma model specific issues label Mar 27, 2025
@abheesht17
Collaborator Author

abheesht17 commented Mar 28, 2025

Okay! @mattdangerw / @divyashreepathihalli - this is ready for review, mostly. I'm filling in the docstrings and the unit tests, but you can review the rest.

Also, probably a good idea to refer to this, because a lot of the vision components were added in the previous PR: #2152.

Collaborator

@divyashreepathihalli divyashreepathihalli left a comment


LGTM, few NIT comments for clean up.

# Truth be told, this is redundant, and we can use
# `vision_indices_input` to infer it directly.
text_mask_input = keras.Input(
    shape=(None,), dtype="int32", name="text_mask"
)
Collaborator

maybe remove this comment?
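For illustration, a minimal numpy sketch of the redundancy the code comment describes: the text mask is fully determined by the positions of the vision tokens, so it could be derived from `vision_indices` rather than passed as a separate input. The helper name and shapes here are hypothetical.

```python
import numpy as np

def text_mask_from_vision_indices(vision_indices, seq_length):
    # Hypothetical helper: positions listed in `vision_indices` hold
    # image tokens; every other position is text.
    mask = np.ones((len(vision_indices), seq_length), dtype=bool)
    for row, indices in enumerate(vision_indices):
        mask[row, indices] = False
    return mask

# One sequence of length 8 with vision tokens at positions 2, 3, 4.
mask = text_mask_from_vision_indices([[2, 3, 4]], 8)
```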

# == Branch: vision model, with non-`None` value for `images` ==

# Check: token IDs should not have fewer than 1, or more than
# `max_images_per_prompt`, start-of-image tokens.
Collaborator

remove commented code
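A sketch of what such a check could look like in plain numpy, counting start-of-image tokens per prompt. The token ID and `max_images_per_prompt` values here are placeholders, not the model's real constants.

```python
import numpy as np

# Assumed values, for illustration only.
START_OF_IMAGE_TOKEN_ID = 255999
max_images_per_prompt = 2

token_ids = np.array([[1, 255999, 5, 255999, 2, 0]])
soi_counts = np.sum(token_ids == START_OF_IMAGE_TOKEN_ID, axis=-1)
# Valid when every prompt has between 1 and `max_images_per_prompt`
# start-of-image tokens.
valid = bool(np.all((soi_counts >= 1) & (soi_counts <= max_images_per_prompt)))
```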

# == Branch: vision model, with non-`None` value for `images` ==

# Check: token IDs should not have fewer than 0, or more than
# `max_images_per_prompt`, start-of-image tokens.
Collaborator

remove commented code

@divyashreepathihalli divyashreepathihalli added the kokoro:force-run Runs Tests on GPU label Mar 28, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Mar 28, 2025
Comment on lines +7 to +9
START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"
Collaborator Author

Need to remove this

Member

they are no longer used right?

Collaborator Author

Yeah, no longer used

Member

@mattdangerw mattdangerw left a comment


Thanks! Haven't done a full read-through of the preprocessor yet, but left some comments.


# Add these for `Gemma3VITAttention`.
if not self.text_only_model:
target_names += ["query_proj", "value_proj"]
Member

Does it hurt to just always leave these as part of the targets? They just won't match right?

Collaborator Author

Hmmm, yeah. But let's keep the `if` condition for clarity.
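To illustrate the point that unmatched targets are harmless: LoRA target names that don't correspond to any layer simply match nothing. A toy sketch (layer and target names are made up for illustration):

```python
def matching_lora_layers(layer_names, target_names):
    # Unmatched target names are simply ignored, which is why always
    # including the vision projections would be harmless in a
    # text-only model: they just won't match any layer.
    return [
        name for name in layer_names if any(t in name for t in target_names)
    ]

# A text-only model has no vision attention layers, so the vision
# projection targets match nothing.
text_only_layers = ["decoder/query_dense", "decoder/value_dense"]
targets = ["query_dense", "value_dense", "query_proj", "value_proj"]
matched = matching_lora_layers(text_only_layers, targets)
```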

Comment on lines +7 to +9
START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"
Member

they are no longer used right?

text_mask: Boolean tensor of shape `(batch_size, seq_length)`.
image_embeddings: tensor. Image embeddings as returned by the
vision encoder (`Gemma3ViT`, usually). Shape:
`(batch_size * num_images_per_prompt, num_vision_tokens_per_image,`
Member

keep this indented so the arg list reads right.
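For context, a minimal numpy sketch of the interleaving this docstring describes: image embeddings are scattered into the positions where `text_mask` is `False`, and text embeddings are kept everywhere else. Function name and shapes are illustrative only.

```python
import numpy as np

def interleave_embeddings(text_embeddings, image_embeddings, text_mask):
    # Place image embeddings where text_mask is False; keep text
    # embeddings everywhere else.
    output = text_embeddings.copy()
    output[~text_mask] = image_embeddings.reshape(
        -1, image_embeddings.shape[-1]
    )
    return output

# Batch of 1, sequence length 4, hidden size 2; one image occupying
# two vision token positions.
text = np.zeros((1, 4, 2))
images = np.ones((1, 2, 2))
mask = np.array([[True, False, False, True]])
out = interleave_embeddings(text, images, mask)
```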

def __init__(self, **kwargs):
# Always do image preprocessing in float32
kwargs.pop("dtype", None)
dtype = "float32"
Member

why is this btw? won't the images get converted to the compute dtype later?

Collaborator Author

@abheesht17 abheesht17 Mar 30, 2025

I want to do standardisation and normalisation in float32 because these ops are sensitive to precision. What do you think - worth keeping?
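A quick numpy demonstration of the precision concern: normalising in float16 introduces a small rounding error relative to a float32 reference. The mean/std values are placeholders, not the model's actual preprocessing constants.

```python
import numpy as np

image = np.full((2, 2, 3), 231, dtype=np.uint8)
mean = std = 127.5  # assumed normalisation constants, for illustration

norm32 = (image.astype("float32") - mean) / std
norm16 = (image.astype("float16") - np.float16(mean)) / np.float16(std)
# float16 cannot represent the normalised value exactly, so a small
# rounding error appears relative to the float32 reference.
max_err = float(np.max(np.abs(norm32 - norm16.astype("float32"))))
```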

START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"


@keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
Member

Not for this PR, but it looks like our `prompts`/`responses` setup will not work for multi-turn conversations. We should consider how we want that to work in the future.

else original_image_shape[-2]
)

if keras.config.backend() == "torch" and not isinstance(
Member

This seems unsafe, given that in some modes of operation `images` is user input, right? We could also have a np array without a `.cpu()` function, for example. Maybe do `images = keras.ops.convert_to_numpy(images)`, which should handle CPU conversion and other cases.
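A sketch of the suggestion above: branch on the object's own capabilities rather than on the active backend, so plain numpy input (or any array-like) is handled safely. In Keras 3, `keras.ops.convert_to_numpy` does this for you; the standalone helper below is a hypothetical equivalent for illustration.

```python
import numpy as np

def to_numpy(images):
    # Handle numpy arrays, torch-like tensors (moving them off the
    # accelerator first), and generic array-likes, without assuming
    # which backend is active.
    if isinstance(images, np.ndarray):
        return images
    if hasattr(images, "detach"):  # looks like a torch tensor
        return images.detach().cpu().numpy()
    return np.asarray(images)

arr = to_numpy([[1, 2], [3, 4]])
```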

if responses is not None:
responses = tf.expand_dims(responses, axis=0)

# There are 8 cases, based on values of
Member

Consider if there's some common utilities to refactor and share between call/generate. I see a lot of duplicated code.

Labels
Gemma Gemma model specific issues

4 participants