
ModelCheckpoint used with save_best_only doesn't handle interruptions, even with BackupAndRestore #430

Closed
nicolasnn opened this issue Sep 13, 2022 · 6 comments
Labels
wontfix This will not be worked on

Comments

@nicolasnn

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04
  • TensorFlow installed from (source or binary): docker image tensorflow/tensorflow:2.10.0
  • TensorFlow version (use command below): 2.10.0 (the issue is the same for 2.8 and 2.9)
  • Python version: 3.8.10

Describe the problem.
I am using ModelCheckpoint callback to save the best model, combined with BackupAndRestore callback to handle interruptions.
The problem arises when re-running the training script after an interruption. The model restored by BackupAndRestore doesn't carry over the previous values of the losses and metrics, so ModelCheckpoint saves the model on the first epoch of the new run regardless of the loss value, and even overwrites the "best" model with a not-as-good model.

Describe the current behavior.

  • I run a training script ✔️
  • The script gets interrupted ✔️
  • When running it again, the model is correctly restored by BackupAndRestore ✔️
  • However, when the model is restored and training resumes, ModelCheckpoint doesn't behave as expected: it saves the model on the first epoch of the new run, regardless of whether the loss improved. ❎

Describe the expected behavior.

  • I run a training script ✔️
  • The script gets interrupted ✔️
  • When running it again, the model is correctly restored by BackupAndRestore ✔️
  • The training state is fully restored, including validation loss and metrics, so ModelCheckpoint keeps doing its job: saving the best model. ✔️

Standalone code to reproduce the issue.

First training

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)

model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

# Callback that fakes an interruption
class InterruptingCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 4:
            raise RuntimeError('Interrupting!')

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])

The output is:

Epoch 1/6
476/500 [===========================>..] - ETA: 0s - loss: 2.4351  
Epoch 1: val_loss improved from inf to 2.07992, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.4528 - val_loss: 2.0799
Epoch 2/6
475/500 [===========================>..] - ETA: 0s - loss: 2.2506
Epoch 2: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2690 - val_loss: 2.0819
Epoch 3/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2212
Epoch 3: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2440 - val_loss: 2.0859
Epoch 4/6
486/500 [============================>.] - ETA: 0s - loss: 2.2116
Epoch 4: val_loss did not improve from 2.07992
500/500 [==============================] - 1s 1ms/step - loss: 2.2372 - val_loss: 2.0894
Traceback (most recent call last):
  File "1st_training.py", line 36, in <module>
    model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb, InterruptingCallback()])
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "1st_training.py", line 33, in on_epoch_begin
    raise RuntimeError('Interrupting!')
RuntimeError: Interrupting!

2nd training

import tensorflow as tf
from tensorflow import keras
import numpy as np

# Dummy datasets
np.random.seed(12)
train_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(500, 10, 4), np.random.randint(0, 5, (500, 1))))
valid_ds = tf.data.Dataset.from_tensor_slices((np.random.rand(100, 10, 4), np.random.randint(0, 5, (100, 1))))

model = keras.Sequential(
    [
        keras.layers.Dense(40, activation="relu"),
        keras.layers.Dense(100, activation="relu"),
        keras.layers.Dense(400, activation="relu"),
        keras.layers.Dense(10, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(1),
    ]
)

model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005))

# Callbacks
backup_cb = keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup_dir')
ckpt_cb = keras.callbacks.ModelCheckpoint('/tmp/best_model', save_best_only=True, monitor='val_loss', verbose=1)

model.fit(train_ds, epochs=6, validation_data=valid_ds, verbose=1, callbacks=[backup_cb, ckpt_cb])

The output is:

Epoch 5/6
476/500 [===========================>..] - ETA: 0s - loss: 2.2097  
Epoch 5: val_loss improved from inf to 2.09097, saving model to /tmp/best_model
500/500 [==============================] - 2s 3ms/step - loss: 2.2322 - val_loss: 2.0910
Epoch 6/6
500/500 [==============================] - ETA: 0s - loss: 2.2200
Epoch 6: val_loss did not improve from 2.09097
500/500 [==============================] - 1s 1ms/step - loss: 2.2200 - val_loss: 2.0978

The problem shows up at "val_loss improved from inf to 2.09097": the state restored by BackupAndRestore doesn't include the previous value of val_loss, so ModelCheckpoint's best value is re-initialized to inf. As a result, ModelCheckpoint doesn't do what it is supposed to do and even overwrites the "best" model with a not-as-good model.
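
A possible workaround, not verified, is to persist ModelCheckpoint's internal best value alongside the backup and restore it when training resumes. A minimal sketch follows; it assumes ModelCheckpoint keeps its monitored best value in a best attribute (an implementation detail), and the class name and state-file path are made up for illustration.

import json
import os

from tensorflow import keras

class RestoreCheckpointBest(keras.callbacks.Callback):
    """Persists and restores the best value tracked by a ModelCheckpoint callback."""

    def __init__(self, checkpoint_cb, state_path='/tmp/backup_dir/ckpt_best.json'):
        super().__init__()
        self.checkpoint_cb = checkpoint_cb
        self.state_path = state_path

    def on_train_begin(self, logs=None):
        # Restore the previously recorded best value, if any, so that
        # save_best_only keeps comparing against it after an interruption.
        if os.path.exists(self.state_path):
            with open(self.state_path) as f:
                self.checkpoint_cb.best = json.load(f)['best']

    def on_epoch_end(self, epoch, logs=None):
        # Record the current best value after every epoch.
        with open(self.state_path, 'w') as f:
            json.dump({'best': float(self.checkpoint_cb.best)}, f)

Passing an instance after the ModelCheckpoint, e.g. callbacks=[backup_cb, ckpt_cb, RestoreCheckpointBest(ckpt_cb)], should keep the comparison value across the two runs above.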

@tilakrayal
Collaborator

@gowthamkpr,
I was able to reproduce the issue on tensorflow v2.8, v2.9 and nightly. Kindly find the gist of it here.

@sampathweb
Collaborator

@nicolasnn - Thanks for reporting this issue with detailed examples.

I am working on an update to BackupAndRestore callback to address this specific scenario. I will submit a PR in the next few weeks. @rchao - please assign this issue to me.

@rchao rchao assigned sampathweb and unassigned rchao Sep 22, 2022
@rchao
Contributor

rchao commented Sep 22, 2022

Thanks Ramesh!

@henrypinkard

henrypinkard commented Mar 3, 2023

I have run into the same issue. @sampathweb are you still planning to fix this? @rchao

@fchollet fchollet transferred this issue from keras-team/keras Sep 22, 2023
@hvgazula

Hello, just curious if this issue has been fixed or if it is still in the works?

@JyotinderSingh
Collaborator

Fixing this requires us to update ModelCheckpoint to provide its state, and to pass the ModelCheckpoint object to BackupAndRestore. The team is not in favor of introducing this dependency at this point.
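
For illustration only, the coupling described here might look roughly like the following. These classes and methods are hypothetical stand-ins, not existing Keras APIs; the explicit reference from the backup callback to the checkpoint callback is exactly the dependency being declined.

import json
import os

class CheckpointStandIn:
    """Hypothetical stand-in for ModelCheckpoint exposing its monitoring state."""

    def __init__(self):
        self.best = float('inf')

    def get_state(self):
        return {'best': self.best}

    def set_state(self, state):
        self.best = state['best']

class BackupStandIn:
    """Hypothetical stand-in for BackupAndRestore that also persists that state."""

    def __init__(self, backup_dir, checkpoint):
        self.state_path = os.path.join(backup_dir, 'checkpoint_state.json')
        self.checkpoint = checkpoint  # the new dependency on the checkpoint callback

    def restore(self):
        # On resume, hand the persisted best value back to the checkpoint callback.
        if os.path.exists(self.state_path):
            with open(self.state_path) as f:
                self.checkpoint.set_state(json.load(f))

    def backup(self):
        # At each epoch end, persist the checkpoint state together with the backup.
        with open(self.state_path, 'w') as f:
            json.dump(self.checkpoint.get_state(), f)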

@JyotinderSingh JyotinderSingh added wontfix This will not be worked on and removed keras-team-review-pending labels Apr 3, 2025
@JyotinderSingh JyotinderSingh reopened this Apr 4, 2025
@JyotinderSingh JyotinderSingh closed this as not planned Apr 4, 2025