Deep learning models have been successful in a wide range of applications, from image classification to natural language processing. However, training these models from scratch can require a large amount of labeled data and computational resources. Fortunately, there are many pre-trained deep learning models available that can be used for a variety of tasks, and the tf.keras.applications
module provides an easy way to load these models in TensorFlow.
In this blog post, we’ll look at a specific implementation of a function that loads pre-trained deep learning models using the tf.keras.applications
module. The code follows a series of steps and uses several parameters to customize the loaded model.
import tensorflow as tf

def get_base_model(network: str, alpha=0.35, training_layer=None, pop_layers=0):
    # 1. Define the model configuration
    kwargs = dict(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
    # 2. Add `alpha` value for MobileNet variants
    if network.lower().startswith("mobilenet"):
        kwargs = {**kwargs, "alpha": alpha}
    # 3. Load the pre-trained model
    if not hasattr(tf.keras.applications, network):
        raise ValueError(f"Architecture not supported: {network}")
    model = getattr(tf.keras.applications, network)(**kwargs)
    # 4. Customize / reduce the loaded model
    if training_layer is not None:
        output_layer = model.get_layer(training_layer)
        model = tf.keras.Model(model.input, output_layer.output)
    elif pop_layers > 0:
        # Remove `pop_layers` layers from the end of the model
        model = tf.keras.Model(model.input, model.layers[-(pop_layers + 1)].output)
    # 5. Disable training on the model
    model.trainable = False
    return model
Step 1: Define the model configuration
The function starts by defining a dictionary called kwargs, which contains common configuration parameters for all networks. These parameters include the input shape of the image, disabling the fully connected layers at the top of the model, and loading pre-trained weights from the ImageNet dataset. The kwargs dictionary can be used to pass additional arguments to the tf.keras.applications module while loading the pre-trained model.
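As a minimal sketch of this step, the configuration can be built and passed to any architecture in tf.keras.applications. Here weights=None is used instead of "imagenet" purely to avoid downloading the pre-trained weights in this illustration:

```python
import tensorflow as tf

# Shared configuration for all networks; weights=None here avoids the
# ImageNet weight download that weights="imagenet" would trigger.
kwargs = dict(input_shape=(224, 224, 3), include_top=False, weights=None)

model = tf.keras.applications.MobileNetV2(**kwargs)

# With include_top=False the classification head is dropped, so the
# model outputs a spatial feature map rather than 1000 class scores.
print(model.output_shape)  # (None, 7, 7, 1280) for MobileNetV2 at 224x224
```

Because include_top=False removes the fully connected head, the resulting model is suitable as a feature extractor for a custom task.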
Step 2: Add alpha value for MobileNet variants
If the specified network is a variant of MobileNet, the alpha value is added to the kwargs dictionary. This parameter controls the width of the network by scaling the number of filters in each convolutional layer. A value of 1.0 represents the original MobileNet architecture, and values less than 1.0 represent smaller networks.
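The effect of alpha can be seen by comparing parameter counts directly (again with weights=None so nothing is downloaded for this sketch):

```python
import tensorflow as tf

# alpha scales the number of filters in each layer, so a smaller alpha
# yields a smaller network with fewer parameters.
full = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None, alpha=1.0)
slim = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights=None, alpha=0.35)

print(full.count_params() > slim.count_params())  # True
```

Note that pre-trained ImageNet weights are only published for a fixed set of alpha values (0.35 is one of them), which is why the function above defaults to 0.35.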
Step 3: Load the pre-trained model
The pre-trained model for the specified architecture is loaded by looking up its name on tf.keras.applications with Python's built-in getattr() function and calling the resulting constructor with the parameters collected in kwargs. If the module has no attribute with that name, the architecture is not supported and an exception is raised.
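This lookup-by-name pattern can be sketched on its own; the membership check with hasattr gives a clearer error than letting getattr raise an AttributeError:

```python
import tensorflow as tf

# Look up an architecture by its attribute name on tf.keras.applications.
name = "ResNet50"
if not hasattr(tf.keras.applications, name):
    raise ValueError(f"Architecture not supported: {name}")

constructor = getattr(tf.keras.applications, name)
# weights=None is used in this sketch to avoid the ImageNet download.
model = constructor(include_top=False, weights=None, input_shape=(224, 224, 3))
print(model.name)  # "resnet50"
```

The name must match the attribute exactly (e.g. "ResNet50", "MobileNetV2"), since getattr performs a case-sensitive lookup.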
Step 4: Customize / reduce the loaded model
The function provides two optional parameters to customize the loaded model:
training_layer: the name of the layer to use as the new output layer. If this parameter is specified, a new model is created with the same input as the pre-trained model but truncated at that layer, so every layer after it is discarded.
pop_layers: the number of layers to remove from the end of the pre-trained model. If this parameter is greater than zero, a new model is created with the same input but with that many layers cut off the end.
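Both options amount to rewiring a new tf.keras.Model to an earlier output tensor. A minimal sketch on a small hand-built model (the layer names below are illustrative, not taken from any pre-trained network):

```python
import tensorflow as tf

# A toy three-layer model standing in for a pre-trained backbone.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Conv2D(8, 3, name="block1")(inputs)
x = tf.keras.layers.Conv2D(8, 3, name="block2")(x)
x = tf.keras.layers.Conv2D(8, 3, name="block3")(x)
model = tf.keras.Model(inputs, x)

# training_layer: truncate at a named layer; "block3" is discarded.
truncated = tf.keras.Model(model.input, model.get_layer("block2").output)

# pop_layers: drop the last n layers; layers[-(n + 1)] is the new output.
n = 1
popped = tf.keras.Model(model.input, model.layers[-(n + 1)].output)

print(truncated.layers[-1].name, popped.layers[-1].name)  # block2 block2
```

Note the -(n + 1) index: layers[-1] is the current last layer, so removing n layers means taking the output of the layer n positions before the end.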
Step 5: Disable training on the model
The trainable attribute of the model is set to False to prevent the pre-trained weights from being modified during training.
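Freezing can be verified by inspecting the weight collections: once trainable is False, all weights move from trainable_weights to non_trainable_weights and are excluded from gradient updates.

```python
import tensorflow as tf

# A tiny model standing in for the pre-trained base.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
])

# Freeze the whole model: its weights will not be updated during training.
model.trainable = False

print(len(model.trainable_weights))      # 0
print(len(model.non_trainable_weights))  # 2 (kernel and bias)
```

This is the standard first phase of transfer learning: train only a new head on top of the frozen base, and optionally unfreeze some layers later for fine-tuning.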
Finally, the function returns the pre-trained model. Overall, this specific implementation of a function to load pre-trained deep learning models provides a straightforward way to customize and use pre-trained models in TensorFlow. By using the tf.keras.applications module and providing parameters to customize the loaded model, researchers and developers can save time and resources while still achieving state-of-the-art results in a variety of applications.