Skip to content

Commit 87fca10

Browse files
authored
Revert "Fix dtype bug in image converter (#2147)" (#2180)
1 parent 38794ac commit 87fca10

File tree

4 files changed

+60
-31
lines changed

4 files changed

+60
-31
lines changed

Diff for: keras_hub/src/layers/preprocessing/image_converter.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from keras_hub.src.utils.preset_utils import get_preset_saver
1717
from keras_hub.src.utils.python_utils import classproperty
1818
from keras_hub.src.utils.tensor_utils import check_bounding_box_support
19+
from keras_hub.src.utils.tensor_utils import in_tf_function
1920
from keras_hub.src.utils.tensor_utils import preprocessing_function
2021

2122

@@ -270,36 +271,45 @@ def call(self, inputs):
270271
else:
271272
x = inputs
272273
if self.scale is not None:
273-
x = x * self._expand_non_channel_dims(self.scale, x)
274+
# If we are scaling, always cast to the compute dtype. We can't
275+
# leave things as an int type if we are scaling to [0, 1].
276+
scale = self._expand_non_channel_dims(self.scale, x)
277+
x, scale = self._convert_types(x, scale, self.compute_dtype)
278+
x = x * scale
274279
if self.offset is not None:
275-
x = x + self._expand_non_channel_dims(self.offset, x)
280+
offset = self._expand_non_channel_dims(self.offset, x)
281+
x, offset = self._convert_types(x, offset, x.dtype)
282+
x = x + offset
276283
if isinstance(inputs, dict):
277284
inputs["images"] = x
278285
else:
279286
inputs = x
280287
return inputs
281288

282289
def _expand_non_channel_dims(self, value, inputs):
283-
input_dtype = keras.backend.standardize_dtype(inputs.dtype)
284-
290+
"""Expand non channel dims so value is broadcastable with inputs."""
285291
unbatched = len(ops.shape(inputs)) == 3
286292
channels_first = self.data_format == "channels_first"
287293
if unbatched:
288294
broadcast_dims = (1, 2) if channels_first else (0, 1)
289295
else:
290296
broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
291-
# If inputs are not a tensor type, return a numpy array.
292-
# This might happen when running under tf.data.
293-
if ops.is_tensor(inputs):
294-
# preprocessing decorator moves tensors to cpu in torch backend and
295-
# processed on CPU, and then converted back to the appropriate
296-
# device (potentially GPU) after preprocessing.
297-
if keras.backend.backend() == "torch" and self.image_size is None:
298-
return ops.expand_dims(value, broadcast_dims).cpu()
299-
expanded = ops.expand_dims(value, broadcast_dims)
300-
return ops.cast(expanded, input_dtype)
301-
else:
302-
return np.expand_dims(value, broadcast_dims).astype(input_dtype)
297+
# A numpy value will work with backend native ops or with tf.data.
298+
return np.expand_dims(value, broadcast_dims)
299+
300+
def _convert_types(self, x, y, dtype):
301+
"""Make sure x and y have the same dtype and are on ths same device."""
302+
if in_tf_function():
303+
# This could happen on any backend if we are running in tf.data.
304+
import tensorflow as tf
305+
306+
return tf.cast(x, dtype), tf.cast(y, dtype)
307+
x = ops.cast(x, dtype)
308+
y = ops.cast(y, dtype)
309+
if keras.backend.backend() == "torch":
310+
# Place on the same device as x (the image).
311+
y = y.to(x.device)
312+
return x, y
303313

304314
def get_config(self):
305315
config = super().get_config()

Diff for: keras_hub/src/layers/preprocessing/image_converter_test.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import keras
55
import numpy as np
66
import pytest
7+
import tensorflow as tf
78
from absl.testing import parameterized
89
from keras import ops
910

@@ -22,6 +23,12 @@ def test_resize_simple(self):
2223
outputs = converter(inputs)
2324
self.assertAllClose(outputs, ops.ones((4, 4, 3)))
2425

26+
def test_resize_dataset(self):
27+
converter = ImageConverter(image_size=(4, 4), scale=1 / 255.0)
28+
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 10, 10, 3)))
29+
batch = ds.batch(2).map(converter).take(1).get_single_element()
30+
self.assertAllClose(batch, tf.zeros((2, 4, 4, 3)))
31+
2532
def test_unbatched(self):
2633
converter = ImageConverter(
2734
image_size=(4, 4),
@@ -35,20 +42,21 @@ def test_unbatched(self):
3542
self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.301569)
3643
self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.852353)
3744

38-
def test_bfloat16_input(self):
45+
def test_dtypes(self):
46+
converter = ImageConverter(image_size=(4, 4), scale=1.0 / 255.0)
47+
int_image = ops.ones((10, 10, 3), dtype="uint8") * 255
48+
float_image = ops.ones((10, 10, 3), dtype="float64") * 255
49+
self.assertDTypeEqual(converter(int_image), "float32")
50+
self.assertDTypeEqual(converter(float_image), "float32")
51+
self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
52+
self.assertAllClose(converter(float_image), np.ones((4, 4, 3)))
3953
converter = ImageConverter(
40-
image_size=(4, 4),
41-
scale=(1.0 / 255.0, 0.8 / 255.0, 1.2 / 255.0),
42-
offset=(0.2, -0.1, 0.25),
43-
dtype="bfloat16",
54+
image_size=(4, 4), scale=1.0 / 255.0, dtype="bfloat16"
4455
)
45-
inputs = ops.ones((10, 10, 3)) * 128
46-
inputs = ops.cast(inputs, "bfloat16")
47-
outputs = converter(inputs)
48-
self.assertEqual(ops.shape(outputs), (4, 4, 3))
49-
self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.703125)
50-
self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * 0.302734)
51-
self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.851562)
56+
self.assertDTypeEqual(converter(int_image), "bfloat16")
57+
self.assertDTypeEqual(converter(float_image), "bfloat16")
58+
self.assertAllClose(converter(int_image), np.ones((4, 4, 3)))
59+
self.assertAllClose(converter(float_image), np.ones((4, 4, 3)))
5260

5361
@parameterized.parameters(
5462
(True, False),

Diff for: keras_hub/src/models/vit/vit_image_converter.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,17 @@ def __init__(
5353

5454
@preprocessing_function
5555
def call(self, inputs):
56+
# TODO: Remove this whole function. We can just use scale and offset
57+
# in the base class.
5658
x = super().call(inputs)
57-
# By default normalize using imagenet mean and std
5859
if self.norm_mean:
59-
x = x - self._expand_non_channel_dims(self.norm_mean, x)
60+
norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
61+
x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
62+
x = x - norm_mean
6063
if self.norm_std:
61-
x = x / self._expand_non_channel_dims(self.norm_std, x)
64+
norm_std = self._expand_non_channel_dims(self.norm_std, x)
65+
x, norm_std = self._convert_types(x, norm_std, x.dtype)
66+
x = x / norm_std
6267

6368
return x
6469

Diff for: keras_hub/src/utils/tensor_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def no_convert_scope():
2828
NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1
2929

3030

31+
def in_tf_function():
32+
if tf is None:
33+
return False
34+
return not tf.executing_eagerly()
35+
36+
3137
def in_no_convert_scope():
3238
return getattr(NO_CONVERT_COUNTER, "count", 0) > 0
3339

0 commit comments

Comments
 (0)