In machine learning, transfer learning is a technique where a pre-trained model, developed for one task, is reused as the starting point for another, similar task. This can be a huge time saver, as it reduces the time and computational resources required to train models from scratch. In this blog post, we will discuss how to use this technique to improve performance on image classification tasks, specifically using TensorFlow.js and the MobileNetV2 model. We will also discuss a novel approach that can significantly increase training speed.
Traditional Transfer Learning in Image Classification
Typically, implementing transfer learning for image classification involves using a base model, like MobileNetV2, and removing the top layers, which are generally responsible for classification. These layers are replaced with some new layers that will be trained to fit our specific task.
The steps usually look like this (a minimal TensorFlow.js sketch follows the list):
- Get MobileNetV2 without top layers
- Disable training on this base model
- Extend this base model with a GlobalAveragePooling2D layer and some new dense layers
- Run the training process using this new (sequential) model; every image in the dataset is fed through the entire model on each training pass
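For reference, here is a minimal sketch of this traditional setup in TensorFlow.js. The model URL, the 'out_relu' cut layer, and the `images`, `labels`, and `numClasses` variables are illustrative assumptions rather than fixed APIs or values, and the snippet is assumed to run in an async context:

import * as tf from '@tensorflow/tfjs';
// Load a MobileNetV2 layers model (the URL is a placeholder for wherever your converted model is hosted)
const mobilenet = await tf.loadLayersModel('https://example.com/mobilenetv2/model.json');
// Remove the top (classification) layers by cutting the graph at the last convolutional activation;
// 'out_relu' is the name used by Keras-converted MobileNetV2, so adjust it for your model
const cutLayer = mobilenet.getLayer('out_relu');
const truncated = tf.model({inputs: mobilenet.inputs, outputs: cutLayer.output});
// Disable training on the base model by freezing its layers
truncated.layers.forEach(layer => layer.trainable = false);
// Extend the base model with a GlobalAveragePooling2D layer and new dense layers
const pooled = tf.layers.globalAveragePooling2d({}).apply(truncated.outputs[0]);
const hidden = tf.layers.dense({units: 100, activation: 'relu'}).apply(pooled);
const output = tf.layers.dense({units: numClasses, activation: 'softmax'}).apply(hidden);
const model = tf.model({inputs: truncated.inputs, outputs: output});
// Train; `images` is an [N, 224, 224, 3] tensor and `labels` an [N, numClasses] one-hot tensor,
// and every image passes through the entire model on every epoch
model.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
await model.fit(images, labels, {epochs: 20, validationSplit: 0.15});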
This traditional method, while effective, has its disadvantages. The main issue is training time: because every training pass runs through the entire model, which can include millions of parameters, training can be relatively slow.
A Novel Approach: Activation Dataset and Training Speed Improvement
To address this, we propose a novel technique that separates the learning process into two stages, which can lead to increased training speed. Here are the steps:
- Get MobileNetV2 without top layers, and add a Flatten() layer
- Perform prediction on each image in the dataset using this base model. The output of this prediction will be referred to as the “activation data”
- Create a new dataset from the activation data
- Create a new model with a single 100-unit Dense layer followed by a softmax output layer
- Run the training process using this new (sequential) model, with the activation dataset as input
By separating the learning process into two stages, the expensive part, passing every image through the large base model, only has to happen once, and the model that is actually trained is much simpler and smaller than the original base model.
The idea here is that the base model, with the top layers removed and replaced with a Flatten() layer, serves to extract features from the images. This is often referred to as a feature extractor. These extracted features (the activation data) are then used as input to the new, smaller model.
// Import TensorFlow.js
import * as tf from '@tensorflow/tfjs';
// (the snippets below are assumed to run inside an async function)
// Step 1: Load MobileNetV2 as a layers model and remove the top layers.
// The URL is a placeholder for wherever your converted MobileNetV2 model is hosted.
const mobilenet = await tf.loadLayersModel('https://example.com/mobilenetv2/model.json');
let baseModel = tf.model({
  inputs: mobilenet.inputs,
  outputs: mobilenet.layers[mobilenet.layers.length - 2].output
});
// Add a Flatten() layer on top of the truncated base model
const flattened = tf.layers.flatten().apply(baseModel.outputs[0]);
baseModel = tf.model({inputs: baseModel.inputs, outputs: flattened});
// Step 2: Run each image through the base model to produce the activation data.
// `images` is assumed to be an array of [1, 224, 224, 3] image tensors.
const activationData = [];
for (const img of images) {
  const activation = baseModel.predict(img).squeeze(); // drop the batch dimension
  activationData.push(activation);
}
// Step 3: Create a new dataset from the activation data, paired with the labels.
// `labels` is assumed to be an array of one-hot label tensors aligned with `images`.
const activationDataset = tf.data.zip({
  xs: tf.data.array(activationData),
  ys: tf.data.array(labels)
});
// Split the activation dataset into training and validation sets (85% - 15%)
const split = Math.floor(activationData.length * 0.85);
const trainData = activationDataset.take(split).batch(16);
const validationData = activationDataset.skip(split).batch(16);
// Get the input shape from the last base model layer
const inputShape = baseModel.outputs[0].shape.slice(1);
const inputSize = tf.util.sizeFromShape(inputShape);
// Step 4: Create the new head model: a single 100-unit Dense layer followed by a softmax output layer
// (`numClasses` is assumed to be the number of classes in your dataset)
const newModel = tf.sequential();
newModel.add(tf.layers.dense({units: 100, inputShape: [inputSize], activation: 'relu'}));
newModel.add(tf.layers.dense({units: numClasses, activation: 'softmax'}));
// Step 5: Run the training process on this new (sequential) model, with the activation dataset as input
newModel.compile({optimizer: 'adam', loss: 'categoricalCrossentropy', metrics: ['accuracy']});
await newModel.fitDataset(trainData, {epochs: 20, validationData});
// Step 6: Join the base model and the newly created model for exporting
const finalModel = tf.model({inputs: baseModel.inputs, outputs: newModel.apply(baseModel.outputs[0])});
You can see how this technique is implemented in the Teachable Machine here:
Exporting the Model
After training, you might want to use the trained model for prediction on new data. To do this, you can simply join the base model and the newly created model:
- Export the base model (MobileNetV2 + Flatten layer)
- Export the newly created model (the Dense layers)
- Join them into a single model
In this way, the final model you use for predictions still includes the feature extraction capabilities of the base model, coupled with the learned classification abilities of the new model.
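As a rough sketch, and reusing the `finalModel` variable from the listing above, saving the joined model and using it for prediction could look like this; the model name and the `newImage` tensor are illustrative placeholders:

// Save the joined model, for example as a browser download
await finalModel.save('downloads://my-transfer-learned-model');
// The joined model accepts raw images directly, so prediction works end to end
// (`newImage` is assumed to be a [1, 224, 224, 3] image tensor)
const probabilities = finalModel.predict(newImage);

TensorFlow.js also supports other save destinations, such as localstorage:// and indexeddb:// in the browser, or file:// when running under Node.js.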
Conclusion
This novel approach to transfer learning represents a potentially significant improvement in terms of training speed, without sacrificing accuracy. By leveraging the feature extraction capabilities of a pre-trained model and training a smaller model on the extracted features, we can effectively reduce the computational complexity and resource requirements of the task. This is a shining example of how creativity and innovation can drive improvement and efficiency in machine learning practices.