5 Best Ways to Save Model Weights After Specific Number of Epochs in Keras

💡 Problem Formulation: In machine learning, it’s essential to save the state of a model at specific milestones during training. For Keras models, users often wish to save the weights after a certain number of epochs to safeguard the training process against interruptions, or for later analysis and comparison. The goal is to periodically checkpoint the model throughout the training process.

Method 1: Callbacks with ModelCheckpoint

An effective way to save model weights after a specific number of epochs is the ModelCheckpoint callback provided by Keras. Its arguments specify the file path to save to, how often to save, and whether to save only the model’s weights or the full model. It is the part of the Keras API designed specifically for checkpointing models at regular intervals during training.

Here’s an example:

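import math
from keras.callbacks import ModelCheckpoint

# Define your model
model = ...

# Create a ModelCheckpoint callback that saves every 5 epochs.
# save_freq counts batches, so convert 5 epochs into batches
# (it replaces the deprecated `period` argument, which newer Keras versions no longer accept).
batch_size = 32
steps_per_epoch = math.ceil(len(x_train) / batch_size)
checkpoint = ModelCheckpoint('path_to_my_model.weights.h5', save_weights_only=True, save_freq=5 * steps_per_epoch)

# Fit the model with the callback (train long enough to reach several checkpoints)
model.fit(x_train, y_train, batch_size=batch_size, epochs=20, callbacks=[checkpoint])

Output: Model weights saved to ‘path_to_my_model.weights.h5’ every 5 epochs (each save overwrites the previous one)

This code snippet uses a ModelCheckpoint callback to save the model’s weights every 5 epochs to a file called ‘path_to_my_model.weights.h5’. Because save_freq is measured in batches, the interval is computed as 5 × steps_per_epoch; older Keras releases expressed the same thing as period=5. By setting save_weights_only to True, only the model’s weights are saved, which takes less disk space than saving the entire model (Keras 3 also requires such weights-only paths to end in .weights.h5). To keep every checkpoint instead of overwriting a single file, include a placeholder such as {epoch:02d} in the file path.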
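For reference, a minimal self-contained sketch of the same idea that can be run as-is is shown below. It assumes a recent Keras release (where save_freq replaces period and weights-only paths must end in .weights.h5); the random data, the tiny Dense model, and the weights_epoch_ file prefix are placeholders, not part of the original example.

```python
import math
import os
import tempfile

import numpy as np
import keras

# Placeholder data and model (assumptions for illustration only)
x_train = np.random.rand(100, 4).astype("float32")
y_train = np.random.rand(100, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

batch_size = 32
every_n_epochs = 5
# One epoch over 100 samples at batch size 32 is ceil(100 / 32) = 4 batches
steps_per_epoch = math.ceil(len(x_train) / batch_size)

ckpt_dir = tempfile.mkdtemp()
checkpoint = keras.callbacks.ModelCheckpoint(
    # {epoch:02d} gives each checkpoint its own file instead of overwriting one
    filepath=os.path.join(ckpt_dir, "weights_epoch_{epoch:02d}.weights.h5"),
    save_weights_only=True,
    save_freq=every_n_epochs * steps_per_epoch,  # measured in batches, not epochs
)
model.fit(x_train, y_train, batch_size=batch_size, epochs=12, verbose=0,
          callbacks=[checkpoint])
print(sorted(os.listdir(ckpt_dir)))
```

With 4 batches per epoch, save_freq is 20, so 12 epochs of training should leave two files, for epochs 5 and 10. Because save_freq is counted across epoch boundaries, this conversion only lines up with whole epochs when every epoch has the same number of batches.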

Method 2: LambdaCallback for Custom Behavior

For more control over the saving process, users can utilize the LambdaCallback to define custom behavior during the training process. This method is especially useful when the saving criteria are more complex than just a fixed frequency.

Here’s an example:

from keras.callbacks import LambdaCallback

# Define your model
model = ...

# Define a custom callback for saving weights
# Keras passes a 0-based epoch index, so epoch + 1 is the number of completed epochs
save_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: model.save_weights('model_{}.weights.h5'.format(epoch + 1)) if (epoch + 1) % 5 == 0 else None)

# Fit the model with the custom callback
model.fit(x_train, y_train, epochs=20, callbacks=[save_callback])

Output: Model weights saved to ‘model_{epoch_number}.weights.h5’ after epochs 5, 10, 15, and 20

This code snippet shows how to use a LambdaCallback to save the model’s weights every 5 epochs. The callback runs a custom function at the end of each epoch and checks the epoch number to decide whether to save. Because Keras counts epochs from 0, the test is (epoch + 1) % 5 == 0; a plain epoch % 5 == 0 would instead save after the 1st, 6th, 11th, and 16th epochs.
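When the saving rule grows beyond a one-line lambda, subclassing keras.callbacks.Callback keeps it readable and reusable. The sketch below is one way to do it, not a callback that Keras ships: SaveEveryNEpochs and its parameters are hypothetical names, the data and model are placeholders, and the .weights.h5 suffix assumes a recent Keras version.

```python
import os
import tempfile

import numpy as np
import keras


class SaveEveryNEpochs(keras.callbacks.Callback):
    """Hypothetical helper: save weights at the end of every n-th epoch."""

    def __init__(self, n, directory):
        super().__init__()
        self.n = n
        self.directory = directory

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index, so epoch + 1 counts completed epochs
        if (epoch + 1) % self.n == 0:
            path = os.path.join(self.directory, f"model_{epoch + 1:02d}.weights.h5")
            self.model.save_weights(path)


# Tiny regression model on random data, just to exercise the callback
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

out_dir = tempfile.mkdtemp()
model.fit(x_train, y_train, epochs=12, verbose=0,
          callbacks=[SaveEveryNEpochs(5, out_dir)])
print(sorted(os.listdir(out_dir)))  # → ['model_05.weights.h5', 'model_10.weights.h5']
```

Inside a callback, self.model refers to the model being trained, so the callback does not need a global reference to model the way the lambda does.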

Method 3: Manual Weight Saving Inside Training Loop

In scenarios where the training loop is manually controlled, for example, in custom training loops using tf.GradientTape, weights can be saved manually after a desired number of steps or epochs.

Here’s an example:

import tensorflow as tf

# Define your model
model = ...

# Define training loop
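for epoch in range(number_of_epochs):
    # Training step code here...
    # Save the weights manually every 5 epochs (epoch is 0-based, so use epoch + 1)
    if (epoch + 1) % 5 == 0:
        model.save_weights('manual_save_epoch_{}.weights.h5'.format(epoch + 1))

Output: Model weights manually saved as ‘manual_save_epoch_{epoch_number}.weights.h5’ every 5 epochs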

This snippet demonstrates saving the model’s weights manually within a custom training loop. The model’s weights are saved every 5 epochs to a designated file that includes the epoch number in its name.
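A fuller version of the loop, with the training step filled in, might look like the sketch below. It assumes TensorFlow 2.x with Keras; the random data, batch size of 16, Adam optimizer, and mean squared error loss are placeholder choices, not part of the original example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Placeholder data and model (assumptions for illustration only)
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(16)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.MeanSquaredError()

out_dir = tempfile.mkdtemp()
number_of_epochs = 10
save_every = 5

for epoch in range(number_of_epochs):
    # One pass over the data with manual gradient computation
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # epoch is 0-based, so (epoch + 1) is the number of completed epochs
    if (epoch + 1) % save_every == 0:
        filename = f"manual_save_epoch_{epoch + 1}.weights.h5"
        model.save_weights(os.path.join(out_dir, filename))

print(sorted(os.listdir(out_dir)))
```

Because the loop is ordinary Python, the save condition can depend on anything available at that point, such as the last batch loss, not just the epoch counter.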

Method 4: Integration with TensorBoard

TensorBoard is a visualization toolkit for TensorFlow that includes features for tracking model training. TensorBoard itself does not save weights, but its callback can be passed to fit() alongside a ModelCheckpoint callback, so weights are saved during training while progress is monitored at the same time.

Here’s an example:

from keras.callbacks import TensorBoard, ModelCheckpoint

# Define your model
model = ...

# Define callbacks for TensorBoard and ModelCheckpoint
tensorboard_callback = TensorBoard(log_dir='./logs')
checkpoint_callback = ModelCheckpoint(filepath='model_weights_{epoch:02d}.weights.h5', save_weights_only=True)

# Fit the model with both callbacks
model.fit(x_train, y_train, epochs=10, callbacks=[tensorboard_callback, checkpoint_callback])

Output: Model weights saved after every epoch as ‘model_weights_01.weights.h5’, ‘model_weights_02.weights.h5’, and so on; training progress can be viewed with tensorboard --logdir ./logs

By passing both callbacks to fit(), we not only save model weights but also visualize different aspects of training. TensorBoard logs metrics such as the loss over time, which complements weight checkpointing by helping identify which saved epoch is worth restoring. By default ModelCheckpoint saves after every epoch; set its save_freq argument (measured in batches) to save less often.
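A common refinement is to give each training run its own directory so that TensorBoard logs and checkpoints stay together. The sketch below assumes a recent Keras with TensorFlow installed (the TensorBoard callback needs it); the directory layout, random data, and tiny model are illustrative placeholders.

```python
import os
import tempfile

import numpy as np
import keras

# Placeholder data and model (assumptions for illustration only)
x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# One directory per run keeps logs and checkpoints side by side
run_dir = tempfile.mkdtemp()
log_dir = os.path.join(run_dir, "logs")
ckpt_dir = os.path.join(run_dir, "checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)

callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, "model_weights_{epoch:02d}.weights.h5"),
        save_weights_only=True,  # default save_freq='epoch': one file per epoch
    ),
]
model.fit(x_train, y_train, epochs=3, verbose=0, callbacks=callbacks)

print(sorted(os.listdir(ckpt_dir)))
```

Pointing tensorboard --logdir at the logs subdirectory then shows the loss curve, and the checkpoint whose epoch number matches a good point on that curve can be reloaded with model.load_weights().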

Bonus One-Liner Method 5: Save Weights After Training

If checkpointing during training is not critical, a simple one-liner can save the model weights once training concludes, which can be a convenient and quick method for less complex scenarios.

Here’s an example:

model.fit(x_train, y_train, epochs=10)
model.save_weights('final.weights.h5')  # Saving after training completes

Output: Final model weights saved in ‘final.weights.h5’

This single line of code will save the weights of the model after the fitting process has completed. It’s the easiest but also the least flexible method as it does not allow for intermediate checkpoints.
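Saved weights are only useful if they can be restored. A minimal sketch of the round trip follows, assuming a recent Keras: load_weights() expects a model with the same architecture, which is why a hypothetical build_model() helper creates both the original and the restored model. The data and layer sizes are placeholders.

```python
import os
import tempfile

import numpy as np
import keras


def build_model():
    # Hypothetical architecture; load_weights needs the same layers in the same order
    return keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])


x_train = np.random.rand(64, 4).astype("float32")
y_train = np.random.rand(64, 1).astype("float32")

model = build_model()
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=2, verbose=0)

path = os.path.join(tempfile.mkdtemp(), "final.weights.h5")
model.save_weights(path)

# Restore into a fresh instance of the same architecture
restored = build_model()
restored.load_weights(path)

same = all(np.allclose(a, b) for a, b in zip(model.get_weights(), restored.get_weights()))
print(same)  # → True
```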

Summary/Discussion
  • Method 1: ModelCheckpoint callback. Easy to use and integrated within Keras. Less flexible for complex saving conditions.
  • Method 2: LambdaCallback for custom save conditions. Highly customizable behavior. May require additional code and conditions.
  • Method 3: Manual saving in a custom loop. Complete control of saving logic. Not suitable for typical Keras workflows with built-in fit method.
  • Method 4: Integration with TensorBoard. Allows for simultaneous saving and visualization. Additional setup required for TensorBoard.
  • Method 5: One-liner post-training save. Simplest approach. Does not provide intermediate checkpoints for long training sessions.