Skip to content

Commit 5faae37

Browse files
authored
bug fix for num_steps=1 (#2373)
1 parent 9dd547a commit 5faae37

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

keras_cv/models/stable_diffusion/stable_diffusion.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,11 @@ def generate_image(
210210

211211
# Iterative reverse diffusion stage
212212
num_timesteps = 1000
213-
ratio = (num_timesteps - 1) / (num_steps - 1)
213+
ratio = (
214+
(num_timesteps - 1) / (num_steps - 1)
215+
if num_steps > 1
216+
else num_timesteps
217+
)
214218
timesteps = (np.arange(0, num_steps) * ratio).round().astype(np.int64)
215219

216220
alphas, alphas_prev = self._get_initial_alphas(timesteps)

keras_cv/models/stable_diffusion/stable_diffusion_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ def test_text_tokenizer_golden_value(self):
7272
[49406, 320, 27111, 9038, 320],
7373
)
7474

75+
@pytest.mark.extra_large
76+
def test_num_steps_equal_to_one_no_error(self):
77+
stablediff = StableDiffusion(128, 128)
78+
_ = stablediff.generate_image(
79+
stablediff.encode_text("thou shall not render"),
80+
num_steps=1,
81+
)
82+
7583
@pytest.mark.extra_large
7684
def test_mixed_precision(self):
7785
try:

0 commit comments

Comments
 (0)