Transfer learning consists of taking features learned on one problem and leveraging them on a new, similar problem. For instance, features from a model that has learned to identify raccoons may be useful to kick-start a model meant to identify tanukis.
Transfer learning is usually done for tasks where your dataset has too little data to train a full-scale model from scratch.
The most common incarnation of transfer learning in the context of deep learning is the following workflow:
1. Take layers from a previously trained model.
2. Freeze them, so as to avoid destroying any of the information they contain during future training rounds.
3. Add some new, trainable layers on top of the frozen layers. They will learn to turn the old features into predictions on a new dataset.
4. Train the new layers on your dataset.
A last, optional step, is fine-tuning, which consists of unfreezing the entire model you obtained above (or part of it), and re-training it on the new data with a very low learning rate. This can potentially achieve meaningful improvements, by incrementally adapting the pretrained features to the new data.
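Before going into the details, here is a rough preview of that workflow in code. This is only a sketch: the Xception base, the 150x150 input shape, and the single-unit binary-classification head are placeholder assumptions, and data loading is omitted.

from tensorflow import keras

# 1. Instantiate a base model with pretrained weights, without its top classification layer.
base_model = keras.applications.Xception(
    weights="imagenet", input_shape=(150, 150, 3), include_top=False
)

# 2. Freeze the base model so the pretrained features aren't destroyed.
base_model.trainable = False

# 3. Add a new, trainable head on top of the frozen base.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)  # keep layers like BatchNormalization in inference mode
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

# 4. Train only the new head on the new data (new_dataset is a placeholder).
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True))
# model.fit(new_dataset, epochs=20)

# Optional fine-tuning: unfreeze and retrain end-to-end with a very low learning rate.
base_model.trainable = True
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss=keras.losses.BinaryCrossentropy(from_logits=True))
# model.fit(new_dataset, epochs=10)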
First, we will go over the Keras trainable API in detail, which underlies most transfer learning & fine-tuning workflows.
Then, we’ll demonstrate the typical workflow by taking a model pretrained on the ImageNet dataset, and retraining it on the Kaggle “cats vs dogs” classification dataset.
This is adapted from Deep Learning with Python and the 2016 blog post "Building powerful image classification models using very little data".
The trainable attribute

Layers & models have three weight attributes:

- weights is the list of all weights variables of the layer.
- trainable_weights is the list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
- non_trainable_weights is the list of those that aren't meant to be trained. Typically they are updated by the model during the forward pass.

Example: the Dense layer has 2 trainable weights (kernel & bias)
import numpy as np
from tensorflow import keras

layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0
In general, all weights are trainable weights. The only built-in layer that has non-trainable weights is the BatchNormalization layer. It uses non-trainable weights to keep track of the mean and variance of its inputs during training. To learn how to use non-trainable weights in your own custom layers, see the guide to writing new layers from scratch.
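As a minimal illustration of that idea, a custom layer can create a non-trainable weight with add_weight(..., trainable=False) and update it itself during the forward pass. The ComputeSum layer below is a hypothetical example (it simply accumulates a running sum of its inputs), not part of the guide's workflow.

import tensorflow as tf


class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super().__init__()
        # Non-trainable weight: never touched by gradient descent,
        # only updated manually in call().
        self.total = self.add_weight(
            shape=(input_dim,), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total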
Example: the BatchNormalization layer has 2 trainable weights and 2 non-trainable weights
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2
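To see which weights are which, you can inspect the layer's weight lists directly. The exact variable names vary by Keras version, but the trainable pair is the scale & offset (gamma & beta), and the non-trainable pair tracks the moving mean and variance:

# Trainable: gamma (scale) & beta (offset). Non-trainable: moving mean & variance.
print([w.name for w in layer.trainable_weights])
print([w.name for w in layer.non_trainable_weights])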
Layers & models also feature a boolean attribute trainable. Its value can be changed. Setting layer.trainable to False moves all the layer's weights from trainable to non-trainable. This is called "freezing" the layer: the state of a frozen layer won't be updated during training (either when training with fit() or when training with any custom loop that relies on trainable_weights to apply gradient updates).
Example: setting trainable to False
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2
When a trainable weight becomes non-trainable, its value is no longer updated during training.
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945
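The same guarantee holds in a custom training loop, as long as gradients are computed and applied only over model.trainable_weights. Below is a minimal sketch of such a loop, assuming TensorFlow's GradientTape and the frozen two-layer model from the example above.

import tensorflow as tf

optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.MeanSquaredError()

x_batch = np.random.random((2, 3))
y_batch = np.random.random((2, 3))

with tf.GradientTape() as tape:
    predictions = model(x_batch, training=True)
    loss = loss_fn(y_batch, predictions)

# Only layer2's kernel & bias appear in model.trainable_weights,
# so the frozen layer1 is never updated.
grads = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))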
Do not confuse the layer.trainable attribute with the argument training in layer.__call__() (which controls whether the layer should run its forward pass in inference mode or training mode). For more information, see the Keras FAQ.
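For example, a Dropout layer has no weights at all, so trainable is irrelevant to it, yet the training argument completely changes its forward pass. A small sketch of the difference (the printed values are illustrative, not exact output):

dropout = keras.layers.Dropout(0.5)
x = np.ones((1, 4))

# Inference mode: inputs pass through unchanged.
print(dropout(x, training=False))
# Training mode: roughly half the units are zeroed out (and the rest rescaled).
print(dropout(x, training=True))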