5 Best Ways to Use TensorFlow for Fashion MNIST Dataset Predictions in Python


💡 Problem Formulation: The Fashion MNIST dataset is a collection of 70,000 grayscale images (60,000 for training and 10,000 for testing) in 10 fashion categories. Predictive modeling on this dataset involves classifying these images into their respective categories. The input is a 28×28 pixel image and the desired output is a class label (e.g., “Shirt”, “Dress”, “Bag”). TensorFlow, an open-source library for numerical computation, is well suited to such classification tasks with deep learning models.
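Keras, bundled with TensorFlow, ships a loader for this dataset that returns the images as integer arrays and the labels as integers 0 to 9. The sketch below assumes a standard TensorFlow 2.x install; `load_fashion_mnist` is a hypothetical helper name, and `CLASS_NAMES` follows the label order documented for Fashion MNIST. The first call downloads the data.

```python
import tensorflow as tf

# Fashion MNIST label order: integer label i corresponds to CLASS_NAMES[i]
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def load_fashion_mnist():
    """Hypothetical helper: return ((x_train, y_train), (x_test, y_test)) with pixels in [0, 1]."""
    # Downloads on first use; x_* are uint8 arrays of shape (N, 28, 28), y_* hold ints 0-9
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    # Scale pixel values from 0-255 integers to float32 values in [0, 1]
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    return (x_train, y_train), (x_test, y_test)
```

Calling `load_fashion_mnist()` should yield image arrays of shape (60000, 28, 28) and (10000, 28, 28), and `CLASS_NAMES[y_train[i]]` turns any integer label back into a readable name.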

Method 1: Building a Simple Neural Network

This method involves constructing a basic neural network using TensorFlow’s high-level Keras API. The `Sequential` model consists of a flattening layer followed by two dense layers: a hidden layer and an output layer with softmax activation for multiclass classification.
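The softmax output layer and the `sparse_categorical_crossentropy` loss used in this method fit together as follows: softmax turns the 10 raw scores into probabilities that sum to 1, and the loss is the negative log of the probability assigned to the correct class. “Sparse” means the labels are plain integers 0 to 9 rather than one-hot vectors. The NumPy sketch below illustrates the math only; it is not Keras’s own implementation, and the function names are chosen for illustration.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax: turns raw scores of shape (N, 10) into probabilities."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(p[label]), where labels are integer class ids."""
    picked = probs[np.arange(len(labels)), labels]  # probability given to the true class
    return float(-np.log(picked).mean())


# A model with no preference: every class gets the same score, hence probability 0.1
logits = np.zeros((2, 10))
probs = softmax(logits)
# The loss is -log(0.1) = log(10), roughly 2.303
loss = sparse_categorical_crossentropy(np.array([6, 3]), probs)
```

Training pushes the probability of the true class toward 1, which drives this loss toward 0.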

Here’s an example:

import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.models import Sequential

# Model building
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Output omitted for brevity

The output of this code will be a compiled TensorFlow model ready for training on the Fashion MNIST dataset.

In this example, we first imported TensorFlow and the necessary layers. Then, we initialized a `Sequential` model and added a `Flatten` layer to convert each 28×28 image into a 784-element 1D array. After that, we added two `Dense` layers, the second of which is the output layer with a softmax activation function that produces one probability per class. Finally, we compiled the model with the `'adam'` optimizer and the `'sparse_categorical_crossentropy'` loss, which expects integer labels (0 to 9) rather than one-hot vectors, gearing it up for the training process.
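A compiled model still has to be trained before it can make predictions. The sketch below rebuilds the same architecture, then shows (in comments) the usual fit, evaluate, and predict steps on the real data. `build_dense_model` and `predict_labels` are hypothetical helper names, and the five epochs and 10% validation split are illustrative choices, not values from this article.

```python
import numpy as np
import tensorflow as tf

CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def build_dense_model():
    """Same architecture as Method 1: Flatten -> Dense(128, relu) -> Dense(10, softmax)."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


def predict_labels(model, images):
    """Return the most probable class name for each 28x28 image (pixels already in [0, 1])."""
    probs = model.predict(images, verbose=0)  # shape (N, 10); each row sums to 1
    return [CLASS_NAMES[i] for i in np.argmax(probs, axis=1)]


# Typical usage with the real data (downloads on first call):
# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# x_train, x_test = x_train / 255.0, x_test / 255.0
# model = build_dense_model()
# model.fit(x_train, y_train, epochs=5, validation_split=0.1)
# test_loss, test_acc = model.evaluate(x_test, y_test)
# print(predict_labels(model, x_test[:3]))
```

Scaling the pixels to [0, 1] before training is a common choice that usually helps the optimizer converge.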

Method 2: Using CNN for Feature Extraction

Convolutional Neural Networks (CNNs) are particularly effective for image data. A CNN can automatically and adaptively learn spatial hierarchies of features from image data. TensorFlow simplifies the creation of CNNs through its layers and models API.
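To see what a single convolution filter and a pooling step compute, here is a NumPy sketch for one input channel and one filter, ignoring the bias and activation. Like Keras’s `Conv2D` with its default `padding='valid'` and stride 1, it slides the kernel without flipping it. The hand-picked vertical-edge kernel is a hypothetical illustration; a CNN learns its kernels during training.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` with stride 1 and no padding."""
    kh, kw = kernel.shape
    out_h, out_w = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Weighted sum of the window under the kernel
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def max_pool_2x2(feature_map):
    """Keep the largest value in each non-overlapping 2x2 window (an odd last row/column is dropped)."""
    h, w = feature_map.shape[0] // 2 * 2, feature_map.shape[1] // 2 * 2
    trimmed = feature_map[:h, :w]
    return trimmed.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


# Toy 28x28 image: dark left half, bright right half (a vertical edge at column 14)
image = np.zeros((28, 28))
image[:, 14:] = 1.0
# Hypothetical vertical-edge detector
kernel = np.array([[-1, 0, 1]] * 3, dtype=float)

fmap = conv2d_valid(image, kernel)  # shape (26, 26); nonzero only in columns 12 and 13
pooled = max_pool_2x2(fmap)         # shape (13, 13); the edge survives in column 6
```

The filter responds only where the window straddles the edge, and pooling halves the resolution while keeping that response.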

Here’s an example:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential

# Model building
model = Sequential([
    Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(28,28,1)),
    MaxPooling2D(pool_size=(2,2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Output omitted for brevity

The output of this code snippet is a TensorFlow model with CNN architecture, ready for training.

Here, the code defines a CNN with a `Conv2D` layer, whose 64 learned 3×3 filters extract local features such as edges and textures, followed by a `MaxPooling2D` layer that downsamples the resulting feature maps. After the pooled features are flattened, dense layers perform the classification. This is all wrapped in a `Sequential` model, compiled with the same settings as in Method 1. Because CNNs learn spatial features directly from neighborhoods of pixels, they are well suited to the Fashion MNIST dataset.
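To make the downsampling concrete, the pure-Python trace below computes the shapes and parameter counts this architecture should produce, assuming Keras’s defaults (`padding='valid'` and stride 1 for `Conv2D`, stride equal to `pool_size` for `MaxPooling2D`); `model.summary()` on the model above should report the same numbers. It also shows the extra channel axis implied by `input_shape=(28,28,1)`, since the dataset loader returns arrays of shape (N, 28, 28). `cnn_shapes` is a hypothetical helper.

```python
import numpy as np


def cnn_shapes(size=28, filters=64, kernel=3, pool=2, dense_units=128, classes=10):
    """Trace output sizes and parameter counts for the Method 2 model."""
    conv = size - kernel + 1          # 28 -> 26 with a 3x3 kernel and no padding
    pooled = conv // pool             # 26 -> 13 after 2x2 max pooling
    flat = pooled * pooled * filters  # 13 * 13 * 64 values per image after Flatten
    params = {
        "conv": kernel * kernel * 1 * filters + filters,  # weights (1 input channel) + biases
        "dense": flat * dense_units + dense_units,
        "output": dense_units * classes + classes,
    }
    return (conv, pooled, flat), params


# Conv2D needs a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
images = np.zeros((5, 28, 28), dtype="float32")
images_4d = images[..., np.newaxis]
```

Most of the parameters sit in the first dense layer (10,816 × 128 weights), which is why pooling before flattening matters.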

Method 3: Implementing Data Augmentation

Data augmentation is a set of techniques to increase the diversity of a training set by applying random but realistic transformations such as rotation, shifting, zooming, and horizontal flipping. TensorFlow offers the `ImageDataGenerator` class (in `tf.keras.preprocessing.image`) to apply these transformations.

Here’s an example:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Define a data generator with augmentations
augmentation = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# Output omitted for brevity

The output of this code will be an augmented dataset generator ready to be used for model training.

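This snippet creates an instance of `ImageDataGenerator` that rotates images by up to 20 degrees, shifts them horizontally and vertically by up to 20% of their size, and randomly flips them horizontally. Calling `augmentation.flow(x_train, y_train, batch_size=32)` returns an iterator that produces freshly augmented batches on the fly, and that iterator can be passed directly to `model.fit()`; `.flow()` expects 4D input of shape (N, 28, 28, 1). The generator’s own `.fit()` method is only needed for options that depend on dataset statistics, such as `featurewise_center`. Training on randomly transformed images can help the model generalize better and reduce overfitting on the Fashion MNIST dataset.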
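Recent TensorFlow documentation marks `ImageDataGenerator` as deprecated for new code and recommends Keras preprocessing layers instead. As a swapped-in alternative (not the method shown above), roughly the same augmentations can be expressed as layers placed at the front of a model, where they apply random transforms only during training. This sketch assumes TensorFlow 2.6 or later, where these layers live directly under `tf.keras.layers`; the mapping from `ImageDataGenerator` arguments is approximate (default fill modes differ, for example), and `build_augmented_cnn` is a hypothetical helper.

```python
import numpy as np
import tensorflow as tf

# Approximate layer equivalents of rotation_range=20, width/height_shift_range=0.2,
# horizontal_flip=True. RandomRotation's factor is a fraction of a full turn: 20 degrees = 20/360.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(20 / 360),
    tf.keras.layers.RandomTranslation(height_factor=0.2, width_factor=0.2),
    tf.keras.layers.RandomFlip("horizontal"),
])


def build_augmented_cnn():
    """Method 2's CNN with augmentation in front; the random layers are inactive at inference."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        augment,  # transforms images only when called with training=True (e.g. inside fit)
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

Because the augmentation is part of the model, the same `model.fit(x_train, y_train, ...)` call used elsewhere works unchanged, with no separate generator.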

Method 4: Hyperparameter Tuning with Keras Tuner

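Hyperparameter tuning searches for good values of settings that are fixed before training rather than learned, such as the number of units in a layer or the learning rate. Keras Tuner, a separate library (installed with `pip install keras-tuner`) that works with TensorFlow’s Keras API, automates this search and can help us find a better model configuration for the Fashion MNIST dataset.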

Here’s an example:

import tensorflow as tf
from keras_tuner import RandomSearch  # pip install keras-tuner (older releases used the module name kerastuner)

def build_model(hp):
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
    model.add(tf.keras.layers.Dense(units=hp.Int('units',
                                          min_value=32,
                                          max_value=256,
                                          step=32),
                                    activation='relu'))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    model.compile(optimizer=tf.keras.optimizers.Adam(hp.Choice('learning_rate', [1e-2, 1e-3, 1e-4])),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

tuner = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    executions_per_trial=3,
    directory='my_dir',
    project_name='helloworld')

# Output omitted for brevity

This code will set up a hyperparameter tuner using the Keras Tuner library, ready to find the best model configuration.

The `build_model` function creates a model and uses the provided hyperparameters (`hp`) to choose the number of units in the dense layer (32 to 256 in steps of 32) and the optimizer’s learning rate (1e-2, 1e-3, or 1e-4). The `RandomSearch` tuner then samples up to `max_trials=5` of these combinations, trains each one from scratch `executions_per_trial=3` times to smooth out random-initialization noise, and keeps the configuration with the best validation accuracy on the Fashion MNIST data.

Bonus One-Liner Method 5: Transfer Learning with TensorFlow Hub

TensorFlow Hub provides a library for reusable machine learning modules, which can be efficiently used with a minimal amount of code. By utilizing pre-trained models, transfer learning can save time and resources.

Here’s an example:

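import tensorflow_hub as hub
import tensorflow as tf

# The MobileNetV2 feature-vector module expects RGB images (3 channels) with values in [0, 1],
# 224x224 by default, so Fashion MNIST's 28x28 grayscale images must be converted first.
# Note: with TensorFlow 2.16+ (Keras 3), hub.KerasLayer may require the tf_keras package
# (tf_keras.Sequential) instead of tf.keras.
model = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
                   input_shape=(224, 224, 3), trainable=False),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Output omitted for brevity

This snippet creates a model with a pre-trained MobileNetV2 as a feature extractor and a dense output layer for classification, compiled and ready for training.

In this code block, we used TensorFlow Hub to incorporate a pre-trained MobileNetV2 model into our `Sequential` model, set as a non-trainable layer for feature extraction. A new `Dense` layer is stacked on top to tailor the model output to the 10 Fashion MNIST classes. Because the module was trained on color images, an input shape of `(28, 28, 1)` does not match it: each Fashion MNIST image has to be scaled to [0, 1], resized to 224×224, and repeated across three channels before it is fed to this model.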
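Converting Fashion MNIST for the MobileNetV2 module (RGB, values in [0, 1], 224×224 by default) is best done batch by batch: upsampling all 60,000 training images at once would need 60,000 × 224 × 224 × 3 × 4 bytes, roughly 36 GB of float32 data. The sketch below does the conversion lazily inside a `tf.data` pipeline; `to_mobilenet_input` and `make_dataset` are hypothetical helper names, and the batch size of 32 is an illustrative choice.

```python
import numpy as np
import tensorflow as tf


def to_mobilenet_input(images, labels):
    """Convert a uint8 batch of shape (N, 28, 28) to float32 (N, 224, 224, 3) in [0, 1]."""
    images = tf.cast(images, tf.float32) / 255.0   # scale pixel values to [0, 1]
    images = images[..., tf.newaxis]               # add a channel axis: (N, 28, 28, 1)
    images = tf.image.resize(images, (224, 224))   # bilinear upsampling to 224x224
    images = tf.image.grayscale_to_rgb(images)     # repeat the single channel three times
    return images, labels


def make_dataset(x, y, batch_size=32):
    """Batch first, then resize, so only one upsampled batch is materialized at a time."""
    return (tf.data.Dataset.from_tensor_slices((x, y))
            .batch(batch_size)
            .map(to_mobilenet_input, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))
```

With the raw arrays from `tf.keras.datasets.fashion_mnist.load_data()`, a call such as `model.fit(make_dataset(x_train, y_train), epochs=...)` would feed the transfer-learning model without holding the whole resized dataset in memory.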

Summary/Discussion

  • Method 1: Simple Neural Network. Easy to set up and understand. Good starting point for beginners. Might be less accurate for complex image classification tasks.
  • Method 2: CNN for Feature Extraction. More sophisticated and suited for image data. Requires more computational power and is more complex to understand than a simple neural network.
  • Method 3: Data Augmentation. Enhances the robustness of the model to variations in new data. Can lead to longer training times due to on-the-fly augmentation processing.
  • Method 4: Hyperparameter Tuning. Can significantly improve model performance. It is computationally intensive and may require a longer time to find the best parameters.
  • Bonus Method 5: Transfer Learning. Leverages pre-trained models for quick setup and potentially improved accuracy with less data. May not be as fine-tuned to specific tasks as other methods that learn directly from the task-specific data.