update Gemma attention for TPU #2130

Merged
35 changes: 23 additions & 12 deletions keras_hub/src/models/gemma/gemma_attention.py
@@ -1,10 +1,13 @@
+import inspect
+
import keras
import numpy as np
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_tpu


class CachedGemmaAttention(keras.layers.Layer):
@@ -103,6 +106,18 @@ def _apply_rope(self, x, start_index):
)
return x

+    def _can_use_flash_attention(self):
+        if not has_flash_attention_support():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if self.logit_soft_cap is None:
+            return True
+        sig = inspect.signature(ops.dot_product_attention)
+        # We can currently only run soft capped attention for keras >= 3.10
+        # and only on TPU.
+        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+
def _compute_attention(
self,
q,
@@ -118,27 +133,23 @@ def _compute_attention(
query_normalization = 1 / np.sqrt(
self.hidden_dim // self.num_query_heads
)
-        use_dot_product_attention = not (
-            self.dropout > 0.0 or (len(q.shape) != 4)
-        )
-        if has_flash_attention_support() and use_dot_product_attention:
-            if self.dropout > 0.0:
-                raise ValueError(
-                    "Flash attention does not support dropout. "
-                    "Please set `dropout` to 0.0."
-                )
+        if self._can_use_flash_attention():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")
-
-            attention_output = ops.dot_product_attention(
+            # Only pass soft cap if needed as not all keras versions support.
+            if self.logit_soft_cap:
+                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+            else:
+                kwargs = {}
+            return ops.dot_product_attention(
                query=q,
                key=k,
                value=v,
                mask=attention_mask,
                scale=query_normalization,
+                **kwargs,
            )
-            return attention_output

q *= ops.cast(query_normalization, dtype=q.dtype)
q_shape = ops.shape(q)
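The new `_can_use_flash_attention` gate requires flash-attention support and zero dropout, and it only routes soft-capped Gemma attention through the fused op when the installed Keras exposes an `attn_logits_soft_cap` argument and the model is running on TPU. Below is a minimal standalone sketch of the two ideas involved, keyword feature detection and tanh-based logit soft capping; the `supports_attn_logits_soft_cap` and `soft_cap_logits` helpers are illustrative only, not part of this PR.

```python
# Illustrative sketch only; the helper names below are not part of the PR.
import inspect

from keras import ops


def supports_attn_logits_soft_cap():
    # Newer Keras releases (3.10+) add an `attn_logits_soft_cap` argument to
    # `ops.dot_product_attention`; detect it at runtime rather than pinning
    # a version string.
    sig = inspect.signature(ops.dot_product_attention)
    return "attn_logits_soft_cap" in sig.parameters


def soft_cap_logits(logits, cap):
    # Tanh-based soft capping as used by Gemma 2: attention logits are
    # squashed smoothly into the range (-cap, cap) before the softmax.
    return cap * ops.tanh(logits / cap)
```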
14 changes: 14 additions & 0 deletions keras_hub/src/models/gemma/gemma_causal_lm_test.py
@@ -12,6 +12,8 @@
)
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import running_on_gpu


class GemmaCausalLMTest(TestCase):
@@ -95,6 +97,18 @@ def test_generate(self):
prompt_ids["padding_mask"][:, :4],
)

+    def test_flash_attention_call(self):
+        if keras.config.backend() != "jax" or not has_flash_attention_support():
+            self.skipTest("`flash_attention` testing requires the Jax backend.")
+
+        with patch("keras.src.backend.nn.dot_product_attention") as mock_func:
+            causal_lm = GemmaCausalLM(**self.init_kwargs)
+            causal_lm.generate("the quick brown fox")
+            if running_on_gpu():
+                mock_func.assert_called()
+            else:
+                mock_func.assert_not_called()
+
def test_generate_with_bfloat16(self):
original_floatx = keras.config.floatx()
keras.config.set_floatx("float16")
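The new test verifies dispatch indirectly: it patches the backend-level `dot_product_attention` and asserts that it is hit during generation when running on a GPU and not hit otherwise. A toy sketch of that patch-and-assert pattern, with made-up stand-ins (`fused_kernel`, `attend`) rather than KerasHub APIs:

```python
# Toy illustration of the mock-based dispatch check in the test above;
# `fused_kernel` and `attend` are made-up names, not KerasHub APIs.
from unittest.mock import patch


def fused_kernel(x):
    return x * 2


def attend(x, use_fused):
    # `fused_kernel` is looked up at call time, so patching the module-level
    # name swaps in the mock.
    return fused_kernel(x) if use_fused else x * 2


def test_dispatch():
    with patch(f"{__name__}.fused_kernel") as mock_kernel:
        attend(3, use_fused=True)
        mock_kernel.assert_called()
        mock_kernel.reset_mock()
        attend(3, use_fused=False)
        mock_kernel.assert_not_called()
```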
32 changes: 32 additions & 0 deletions keras_hub/src/utils/keras_utils.py
@@ -72,3 +72,35 @@ def has_flash_attention_support():
        return True
    else:
        return False
+
+
+def running_on_tpu():
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+        return any(d.platform == "tpu" for d in devices)
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        return bool(tf.config.list_logical_devices("TPU"))
+    elif backend == "torch":
+        return False
+
+
+def running_on_gpu():
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+        return any(d.platform == "gpu" for d in devices)
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        return bool(tf.config.list_logical_devices("GPU"))
+    elif backend == "torch":
+        import torch
+
+        return torch.cuda.is_available()
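Both helpers report what the active backend can actually see rather than reading environment variables: JAX and TensorFlow are asked for their visible TPU/GPU devices, and the Torch branch checks CUDA for GPUs (and always reports False for TPUs). Assuming the JAX backend, a quick probe might look like this; the printed values depend entirely on the runtime (e.g. a Cloud TPU VM versus a CPU-only machine):

```python
# Quick probe of the new helpers under the JAX backend; output is
# entirely runtime-dependent.
import jax

from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu

print([d.platform for d in jax.devices()])  # e.g. ["tpu", "tpu", ...] on a TPU VM
print(running_on_tpu(), running_on_gpu())   # e.g. True False on that same VM
```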