5 Best Ways to Implement Transfer Learning in Python Using Keras


💡 Problem Formulation: Transfer learning has become a cornerstone in deep learning, allowing developers to leverage pre-trained models to solve similar problems with less data, time, and computational resources. In this article, we focus on how to apply transfer learning using Keras in Python. Imagine you have a dataset of animal images and want to classify them into various species. Instead of training a model from scratch, we can employ a model pre-trained on a large-scale image dataset like ImageNet and fine-tune it to our specific task.

Method 1: Using a Pre-Trained Model as a Feature Extractor

One popular transfer learning approach is to use a pre-trained model as a feature extractor. You remove the classification layers at the top of the pre-trained network, run your new data through the remaining layers, and use their output activations as inputs to a new, smaller model that you train from scratch.

Here’s an example:

from keras.applications import VGG16
from keras.layers import Dense, Flatten
from keras.models import Model

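# Load the VGG16 convolutional base pre-trained on ImageNet, without its classifier top.
# input_shape must be fixed here, otherwise Flatten cannot determine its output size.
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze all VGG16 layers so only the new layers are trained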

# Create a new model on top
x = Flatten()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
model.compile(optimizer='adam', loss='categorical_crossentropy')

# Summary of the model to verify the new architecture
model.summary()

The output is a summary of the new model: the VGG16 layers followed by the new Flatten and Dense layers. Because the base is frozen, all VGG16 weights are reported under "Non-trainable params".

This code snippet builds a new model that uses VGG16 without its top layers as a feature extractor. The new dense layers are specific to our task, in this case predicting 10 different classes. This method is efficient because the VGG16 weights are never updated: only the new dense layers are trained.
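The model above still runs every image through VGG16 on each training step. A closer match to the feature-extractor description is to compute the VGG16 activations once, cache them, and train a small classifier on the cached arrays. The sketch below assumes in-memory RGB images in the 0 to 255 range and integer class labels; extract_features and build_classifier are hypothetical helper names, and the GlobalAveragePooling2D head is an illustrative choice rather than the Flatten head used above.

```python
import numpy as np
from keras import layers, models
from keras.applications import VGG16
from keras.applications.vgg16 import preprocess_input


def extract_features(images, weights="imagenet"):
    """Run images once through the frozen VGG16 convolutional base (hypothetical helper)."""
    base = VGG16(weights=weights, include_top=False, input_shape=images.shape[1:])
    # VGG16 expects its own preprocessing (RGB to BGR, ImageNet mean subtraction);
    # astype() makes a copy, so the caller's array is left untouched.
    x = preprocess_input(images.astype("float32"))
    return base.predict(x, verbose=0)


def build_classifier(feature_shape, num_classes=10):
    """Small model trained from scratch on the cached features (hypothetical helper)."""
    model = models.Sequential([
        layers.Input(shape=feature_shape),
        layers.GlobalAveragePooling2D(),
        layers.Dense(256, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),
    ])
    # Integer labels, so the sparse variant of categorical cross-entropy is used.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

With real data this would be used as features = extract_features(x_train) followed by build_classifier(features.shape[1:]).fit(features, y_train, epochs=10), where x_train and y_train stand in for your own arrays. Because the base runs only once, each training epoch only touches the small classifier.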

Method 2: Fine-Tuning a Pre-Trained Model

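Fine-tuning means continuing to train some of a pre-trained model's weights with backpropagation on the new data, usually with a small learning rate so the learned features are not destroyed. It is usually best to fine-tune only the higher layers, because the earlier layers capture more general features such as edges and textures.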

Here’s an example:

from keras.applications import VGG16
from keras.models import Model
from keras.layers import Dense, Flatten
from keras.optimizers import Adam

# Load VGG16 model pre-trained on ImageNet data (a fixed input_shape is needed for Flatten)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = Flatten()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Fine-tune only the last convolutional block: layers[-4:] are
# block5_conv1, block5_conv2, block5_conv3 and block5_pool
base_model.trainable = True
for layer in base_model.layers[:-4]:
    layer.trainable = False

# A small learning rate keeps the pre-trained weights from changing too abruptly
model.compile(optimizer=Adam(learning_rate=1e-5), loss='categorical_crossentropy')

# Summary with a per-layer "Trainable" column
model.summary(show_trainable=True)

The output is a summary of the model. Its trainable parameter count now covers the new dense layers plus the last four VGG16 layers (block5_conv1 to block5_conv3 and block5_pool), while all earlier VGG16 layers stay non-trainable.

The snippet shows how to unfreeze the last four layers of VGG16 (its final convolutional block) for training. This lets the model adapt those parameters to the new dataset while keeping the rest of the original VGG16 layers frozen.
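Unfreezing layers while the new dense layers are still randomly initialized can push large gradients into the pre-trained weights. A common remedy is a two-stage schedule: first train only the new head with the base frozen, then unfreeze VGG16's last block and recompile with a much smaller learning rate. The sketch below assumes that schedule; the helper names, the GlobalAveragePooling2D head, and the learning rates 1e-3 and 1e-5 are illustrative choices, not values from this article. It selects layers by their block5 name prefix instead of a slice index, which makes the intent explicit.

```python
from keras import layers, models, optimizers
from keras.applications import VGG16


def build_frozen_model(num_classes=10, weights="imagenet", input_shape=(224, 224, 3)):
    """Stage 1: frozen VGG16 base plus a new trainable head (hypothetical helper)."""
    base = VGG16(weights=weights, include_top=False, input_shape=input_shape)
    base.trainable = False
    inputs = layers.Input(shape=input_shape)
    # training=False keeps the base in inference mode (matters for models with BatchNorm).
    x = base(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    return model, base


def unfreeze_block5(model, base, learning_rate=1e-5):
    """Stage 2: unfreeze only VGG16's block5 layers, then recompile (hypothetical helper)."""
    base.trainable = True
    for layer in base.layers:
        layer.trainable = layer.name.startswith("block5")
    # Recompiling is required for the new trainable flags to take effect.
    model.compile(optimizer=optimizers.Adam(learning_rate=learning_rate),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    return model
```

In practice you would call model.fit(...) after build_frozen_model until the head has converged, then call unfreeze_block5(model, base) and continue with model.fit(...) for a few more epochs.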

Method 3: Extending a Pre-Trained Model with Custom Layers

An effective strategy in transfer learning is to extend a pre-trained model with custom layers tailored to your specific task. This allows the model to take advantage of the learned feature detectors in the pre-trained base while expanding its capability to the new dataset.

Here’s an example:

from keras.applications import VGG16
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten

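# Load VGG16 without the top layers; a fixed input_shape lets Flatten compute its size
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze the pre-trained convolutional base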

# Create a new Sequential model
new_model = Sequential([
    base_model,
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

new_model.compile(optimizer='adam', loss='categorical_crossentropy')
new_model.summary()

Output will display a well-defined model architecture summary that includes the pre-trained VGG16 model and the newly added custom layers.

This code adds custom layers, including Dropout for regularization, on top of the frozen VGG16 base to adapt it to the new classification task. It’s a quick way to leverage the high-level features learned by the pre-trained model while training only the new layers on the data at hand.
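Much of the design effort for this method goes into the head. One choice with a large effect on model size is how the convolutional output is reduced before the first dense layer. For a 224×224 input, VGG16 without its top produces a 7×7×512 feature map, so Flatten feeds 25,088 values into Dense(256), while GlobalAveragePooling2D feeds only 512. The sketch below builds the head from Method 3 on its own, using a hypothetical build_head helper with the pooling swap as an illustrative alternative, so the parameter counts can be compared directly.

```python
from keras import layers, models


def build_head(pooling="flatten", feature_shape=(7, 7, 512), num_classes=10):
    """The custom head from Method 3, with a choice of how to reduce the feature map."""
    reduce = layers.Flatten() if pooling == "flatten" else layers.GlobalAveragePooling2D()
    return models.Sequential([
        layers.Input(shape=feature_shape),
        reduce,
        layers.Dense(256, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ])


for choice in ("flatten", "gap"):
    print(choice, build_head(choice).count_params())
# flatten 6425354
# gap 133898
```

The arithmetic behind the counts: 25,088 × 256 + 256 = 6,422,784 parameters for the first dense layer after Flatten, versus 512 × 256 + 256 = 131,328 after pooling, plus 256 × 10 + 10 = 2,570 for the output layer in both cases. A pooled head is often preferred when the new dataset is small, since it has far fewer weights to fit.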

Method 4: Using Pre-Trained Models with Data Augmentation

Combining pre-trained models with data augmentation can greatly improve the generalization of the model to your specific dataset. Augmentation techniques such as rotations, shifts, and flips introduce variety in the training examples without collecting new data.

Here’s an example:

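from keras.applications import VGG16
from keras.applications.vgg16 import preprocess_input
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.preprocessing.image import ImageDataGenerator

# Load VGG16 model pre-trained on ImageNet data
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False

# Create a data generator for augmentation; preprocessing_function applies
# VGG16's own input preprocessing after each random transformation
augmented_data_gen = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                                        height_shift_range=0.1, horizontal_flip=True,
                                        preprocessing_function=preprocess_input)

# Add custom layers on top of the base_model
model = Sequential([
    base_model,
    Flatten(),
    Dense(256, activation='relu'),
    Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy')

# Summary to display the model architecture
model.summary()

# Training feeds augmented batches from the generator, e.g. (x_train and
# y_train are placeholders for your own image array and one-hot labels):
# model.fit(augmented_data_gen.flow(x_train, y_train, batch_size=32), epochs=10)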

The output is a summary of the model architecture, with an input shape of (224, 224, 3) matching VGG16's default input size. The augmentation settings do not appear in the summary, because they are applied to the data rather than to the model.

This example sets up data augmentation with Keras’ ImageDataGenerator. During training, the generator’s flow() or flow_from_directory() method yields randomly transformed batches, so the model sees varied examples in every epoch, which helps produce a more robust classifier.
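Newer Keras versions also provide augmentation as layers (RandomFlip, RandomRotation, RandomTranslation) that live inside the model and are active only during training. The sketch below approximates the generator settings above with those layers: RandomRotation(30 / 360) because the layer takes a fraction of a full turn rather than degrees, and RandomTranslation(0.1, 0.1) for the 10% shifts. It assumes the inputs have already been passed through VGG16's preprocess_input, and build_augmented_model is a hypothetical name.

```python
from keras import layers, models
from keras.applications import VGG16


def build_augmented_model(num_classes=10, weights="imagenet", input_shape=(224, 224, 3)):
    """Frozen VGG16 with augmentation layers in front (hypothetical helper).

    Assumes inputs were already preprocessed with keras.applications.vgg16.preprocess_input.
    """
    augmentation = models.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(30 / 360),     # up to about +/-30 degrees
        layers.RandomTranslation(0.1, 0.1),  # up to 10% shift in height and width
    ], name="augmentation")

    base = VGG16(weights=weights, include_top=False, input_shape=input_shape)
    base.trainable = False

    inputs = layers.Input(shape=input_shape)
    x = augmentation(inputs)  # random transforms are applied only when training=True
    x = base(x, training=False)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation="relu")(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = models.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
```

Because the augmentation is part of the model, a plain model.fit(x_train, y_train) call (with your own preprocessed arrays) receives freshly transformed images each epoch, and model.predict sees the images unchanged.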

Bonus Method 5: Loading a Pre-Trained Model and Making Predictions

Sometimes in transfer learning, all you need is to apply a pre-trained model directly to make predictions on the new data without any further modification or training.

Here’s an example:

from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing.image import img_to_array, load_img
import numpy as np

# Load an image and preprocess it
image = load_img('example.jpg', target_size=(224, 224))
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
image = preprocess_input(image)

# Load VGG16 pre-trained on ImageNet and predict
model = VGG16(weights='imagenet')
prediction = model.predict(image)
print(decode_predictions(prediction, top=3)[0])

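decode_predictions returns one list per input image; [0] selects the list for our single image, which holds three tuples of the form (ImageNet class ID, human-readable class name, probability score) for the top 3 predicted classes.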

This snippet of code shows the simplest way to use a pre-trained model: making a prediction. It loads a pre-trained VGG16 model, preprocesses an input image, and outputs the top 3 predictions for what’s within the image.

Summary/Discussion

  • Method 1: Using a Pre-Trained Model as a Feature Extractor. Strengths: Simplified training, leverage existing architectures. Weaknesses: Not as finely tuned to the specific task.
  • Method 2: Fine-Tuning a Pre-Trained Model. Strengths: Improved task-specific performance. Weaknesses: Requires more computation and caution to avoid overfitting.
  • Method 3: Extending a Pre-Trained Model with Custom Layers. Strengths: Balances high-level feature usage with task-specific adaptation. Weaknesses: Added complexity in designing layers.
  • Method 4: Using Pre-Trained Models with Data Augmentation. Strengths: Enhanced model generalization. Weaknesses: Potentially increased training time.
  • Bonus Method 5: Loading a Pre-Trained Model and Making Predictions. Strengths: Quick and easy use of advanced model capabilities. Weaknesses: No customization or fine-tuning to specific tasks; predictions are limited to the 1,000 ImageNet classes.