Commit 38794ac

Update gating condition to include check for supporting GPUs for flash attention (#2184)
* update gating condition for flash attention
* fix test
* update utils
* fix tests
* add on t4 on deny list
* address review comments
* address comments
1 parent: b997444 · commit: 38794ac

11 files changed: +85 -29 lines


Diff for: keras_hub/src/models/gemma/gemma_attention.py

+17 -10

@@ -6,7 +6,9 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -106,17 +108,22 @@ def _apply_rope(self, x, start_index):
         )
         return x
 
-    def _can_use_flash_attention(self):
-        if not has_flash_attention_support():
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if self.logit_soft_cap is None:
-            return True
-        sig = inspect.signature(ops.dot_product_attention)
-        # We can currently only run soft capped attention for keras >= 3.10
-        # and only on TPU.
-        return running_on_tpu() and "attn_logits_soft_cap" in sig.parameters
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
@@ -140,7 +147,7 @@ def _compute_attention(
                 cache_update_index=cache_update_index,
             )
 
-        if self._can_use_flash_attention():
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
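
For readers checking how the new gating resolves on their own machine, the helpers it relies on can be queried directly. A minimal sketch, assuming keras_hub is installed and a Keras backend is configured (the printed labels are illustrative, not part of the library):

from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu

# Report which branch of the new gating condition applies on this machine.
print("fused op available:", fused_attention_op_available())
print("running on GPU:", running_on_gpu())
print("running on TPU:", running_on_tpu())
if running_on_gpu():
    # False on deny-listed GPUs such as the T4 (see keras_utils.py below).
    print("GPU supports fused op:", gpu_supports_fused_attention_op())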

Diff for: keras_hub/src/models/gemma/gemma_causal_lm_test.py

+5 -2

@@ -12,7 +12,7 @@
 )
 from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
 from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import running_on_gpu
 
 
@@ -98,7 +98,10 @@ def test_generate(self):
         )
 
     def test_flash_attention_call(self):
-        if keras.config.backend() != "jax" or not has_flash_attention_support():
+        if (
+            keras.config.backend() != "jax"
+            or not fused_attention_op_available()
+        ):
             self.skipTest("`flash_attention` testing requires the Jax backend.")
 
         with patch("keras.src.backend.nn.dot_product_attention") as mock_func:

Diff for: keras_hub/src/models/gemma3/gemma3_attention.py

+2 -2

@@ -7,7 +7,7 @@
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
@@ -140,7 +140,7 @@ def _apply_rope(self, x, start_index):
         return x
 
     def _can_use_flash_attention(self):
-        if not has_flash_attention_support():
+        if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False

Diff for: keras_hub/src/models/gemma3/gemma3_causal_lm_test.py

+5 -2

@@ -12,7 +12,7 @@
 )
 from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import running_on_gpu
 
 
@@ -77,7 +77,10 @@ def test_text_causal_lm_basics(self):
         )
 
     def test_text_flash_attention_call(self):
-        if keras.config.backend() != "jax" or not has_flash_attention_support():
+        if (
+            keras.config.backend() != "jax"
+            or not fused_attention_op_available()
+        ):
             self.skipTest("`flash_attention` testing requires the Jax backend.")
 
         with patch("keras.src.backend.nn.dot_product_attention") as mock_func:

Diff for: keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py

+2 -2

@@ -5,7 +5,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class GPTNeoXAttention(keras.layers.Layer):
@@ -125,7 +125,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        if has_flash_attention_support() and self.dropout == 0:
+        if fused_attention_op_available() and self.dropout == 0:
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:

Diff for: keras_hub/src/models/llama/llama_attention.py

+2 -2

@@ -5,7 +5,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class LlamaAttention(keras.layers.Layer):
@@ -185,7 +185,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:

Diff for: keras_hub/src/models/mistral/mistral_attention.py

+2 -2

@@ -5,7 +5,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -196,7 +196,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:

Diff for: keras_hub/src/models/phi3/phi3_attention.py

+2 -2

@@ -8,7 +8,7 @@
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class Phi3Attention(keras.layers.Layer):
@@ -217,7 +217,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self.softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:

Diff for: keras_hub/src/models/qwen/qwen_attention.py

+2 -2

@@ -5,7 +5,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class QwenAttention(keras.layers.Layer):
@@ -263,7 +263,7 @@ def _compute_attention(
         Returns:
             attention_output: Output tensor after applying attention.
         """
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:

Diff for: keras_hub/src/models/stable_diffusion_3/mmdit.py

+2 -2

@@ -6,8 +6,8 @@
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gelu_approximate
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
@@ -771,7 +771,7 @@ def build(self, inputs_shape, context_shape, timestep_embedding_shape):
     def _compute_attention(self, query, key, value):
         batch_size = ops.shape(query)[0]
 
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             encoded = ops.dot_product_attention(

Diff for: keras_hub/src/utils/keras_utils.py

+44 -1

@@ -55,7 +55,7 @@ def standardize_data_format(data_format):
     return data_format
 
 
-def has_flash_attention_support():
+def fused_attention_op_available():
     if (
         hasattr(keras.config, "is_flash_attention_enabled")
         and keras.config.backend() == "jax"
@@ -104,3 +104,46 @@ def running_on_gpu():
         import torch
 
         return torch.cuda.is_available()
+
+
+def gpu_supports_fused_attention_op():
+    deny_list = ["T4"]
+    for denied_gpu in deny_list:
+        if any(denied_gpu in gpu.upper() for gpu in get_gpu_names()):
+            return False
+    return True
+
+
+def get_gpu_names():
+    """Detects and returns the names of available GPUs based on the backend.
+
+    Note:
+        The format and content of the returned GPU names are **not normalized**
+        and vary significantly depending on the active backend. This function
+        provides the names as reported by the respective backend's API.
+    """
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+
+        return [getattr(d, "device_kind", "") for d in devices]
+
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        gpus = tf.config.list_physical_devices("GPU")
+        return [
+            tf.config.experimental.get_device_details(gpu)["device_name"]
+            for gpu in gpus
+        ]
+    elif backend == "torch":
+        import torch
+
+        return [
+            torch.cuda.get_device_name(i)
+            for i in range(torch.cuda.device_count())
+        ]
+    else:
+        return [""]
