Intro to Transfer Learning in TensorFlow

David White
5 min read · Oct 28, 2021


A quick guide to an invaluable deep learning tool

Photo by Nagara Oyodo on Unsplash

To clone and run the exercises shown in this article, view the companion notebook on Kaggle.

Tired of spending hours training your image classifiers only to end up with subpar results? Wish there were a better way to build robust neural nets without having to start from scratch every time? Well, wish no longer: transfer learning is here, and it’s just what you need! In a nutshell, transfer learning lets you piggyback off the hard work of deep learning experts who have already trained and tuned neural nets for just about any purpose. Today I’ll walk you through how to stand on the shoulders of giants to speed up your deep learning development process.

The Experiment

Training a good image classification model can be particularly arduous, which makes transfer learning a must-have in any competent deep learning practitioner’s tool belt. In TensorFlow, this is most easily done using TensorFlow Hub. To see how pre-trained models can drastically speed up your training process, we will build two convolutional networks to classify a set of images: one from scratch, and one that uses the Inception V3 image classification model (with its classification head removed so it outputs feature vectors) as the base for our own classifier. We’ll train each model for 10 epochs and see which one produces better results. For our image data, we’ll use this data set from Kaggle containing labelled images of fish.

A red sea bream sampled from our data set.

Preprocessing the Data

Before creating our models, we’ll prepare our data. We first need to separate our data into training and validation sets. For this we can use the split-folders library. We’ll then use the ImageDataGenerator class from the preprocessing package in Keras to load and apply transformations such as rescaling and augmentation to our images without altering them on-disk.

# Before running, make sure to pip install the split-folders package
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub # <--- To load our transfer learning model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import splitfolders as sf

# Split the data 80/20 into training and validation sets. Save the split data in a folder called "fish".
sf.ratio("/a-large-scale-fish-dataset/NA_Fish_Dataset", ratio=(0.8, 0.2), output="fish", seed=42)

Next, we’ll set up our ImageDataGenerator objects to load and preprocess our lovely fish photos. We use separate generators for our training and validation data so that we can apply augmentation to our training set while leaving our validation set untouched (aside from rescaling for normalization of course.)

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=180,
    shear_range=0.3,
    width_shift_range=0.3,
    height_shift_range=0.3,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode="nearest"
)
val_datagen = ImageDataGenerator(rescale=1./255)

# We then grab our data from their respective folders.
train_generator = train_datagen.flow_from_directory(
    "fish/train",
    target_size=(300, 300),
    class_mode="categorical"
)
val_generator = val_datagen.flow_from_directory(
    "fish/val",
    target_size=(300, 300),
    class_mode="categorical"
)
Shrimp photo before augmentation
Augmented shrimp photo

Creating the Models

First we’ll create our “control” model from scratch: a basic convolutional net with two convolutional layers, each followed by a max pooling layer, then a flatten and a 10% dropout layer. We’ll use the same set of dense layers as the top for each of our models.

control_model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3, 3), input_shape=(300, 300, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(9, activation="softmax")
])
control_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["acc"]
)
control_model.summary()

As the summary shows, there are over 174 million parameters to be trained in this model, the vast majority of them sitting between the flattened feature map and the first dense layer. Yikes.
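If you’re curious where that number comes from, here’s a rough back-of-the-envelope check (a sketch assuming the 300×300×3 input and the layer sizes above; model.summary() remains the authoritative count):

conv1 = 64 * (3 * 3 * 3) + 64        # 1,792 params; output 298x298x64, pooled to 149x149x64
conv2 = 128 * (3 * 3 * 64) + 128     # 73,856 params; output 147x147x128, pooled to 73x73x128
flat = 73 * 73 * 128                 # 682,112 values coming out of Flatten
dense1 = flat * 256 + 256            # ~174.6 million params, the overwhelming majority
dense2 = 256 * 9 + 9                 # 2,313 params
print(conv1 + conv2 + dense1 + dense2)  # roughly 174.7 million

Now we’ll create our experimental transfer learning model.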

tl_model = tf.keras.Sequential([
    hub.KerasLayer(
        "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/5",
        trainable=False  # <--- Lock the pre-trained layer so it isn't updated during training
    ),
    tf.keras.layers.Dense(256, activation="relu"),  # Same top layers as control
    tf.keras.layers.Dense(9, activation="softmax")
])
tl_model.build([None, 300, 300, 3])
tl_model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["acc"]
)
tl_model.summary()

The transfer learning model has only about 22 million parameters, and since we locked the imported layer, most of them will never be trained at all: the hard work of training that base has already been done for us.
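If you want to verify that, here’s a quick sanity check (a sketch; the trainable/non-trainable breakdown at the bottom of model.summary() tells the same story):

# Count how many of the model's parameters will actually be updated during training
trainable = int(sum(np.prod(w.shape.as_list()) for w in tl_model.trainable_weights))
total = int(sum(np.prod(w.shape.as_list()) for w in tl_model.weights))
print(f"{trainable:,} trainable out of {total:,} total parameters")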

Training the Models

Now for the moment of truth. We’ll train each model for 10 epochs and see how each one fares.

control_model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)

We see that after 10 epochs our control model is barely sitting at 20% validation accuracy. Though there doesn’t appear to be much evidence of overfitting at this stage, we’re still far below what most would consider an acceptable level of accuracy.
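One easy way to keep an eye on overfitting, by the way, is to capture the History object that fit() returns and plot the training and validation accuracy curves side by side. Here’s a minimal sketch, assuming you assigned the call above to a variable such as history = control_model.fit(...):

import matplotlib.pyplot as plt

# With metrics=["acc"], Keras records "acc" and "val_acc" for every epoch
plt.plot(history.history["acc"], label="training accuracy")
plt.plot(history.history["val_acc"], label="validation accuracy")
plt.xlabel("epoch")
plt.legend()
plt.show()

Let’s see how our experimental model does.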

tl_model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)

After just one (!!) epoch of training, our transfer learning model has almost twice the validation accuracy that our handmade model achieved. After 10 epochs, the new model is above 90% accuracy with almost no evidence of overfitting.
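If you’d like to confirm the final numbers outside of the training loop, a quick sketch using evaluate() on the validation generator:

# Run the validation set through the trained model one more time and report loss and accuracy
val_loss, val_acc = tl_model.evaluate(val_generator)
print(f"Validation accuracy: {val_acc:.2%}")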

Conclusion

Now we’ve seen the remarkable results that can be gained by building on the sturdy platforms provided by the hard work of others. While it may feel a little like cheating, transfer learning is an invaluable tool that is not only extremely effective but also highly flexible. I like to think of it as buying a suit off the rack: you can get a pretty good fit just by finding one in your size, and you can always have it tailored for an even better fit.
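As a small preview of that tailoring step, fine-tuning usually amounts to unfreezing the pre-trained base and re-compiling with a much smaller learning rate before training a little longer. A rough sketch, not a full recipe (and note that some Hub modules prefer trainable=True to be set when the KerasLayer is constructed):

# Unfreeze the hub layer so its weights can be nudged, then re-compile with a tiny learning rate
tl_model.layers[0].trainable = True
tl_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="categorical_crossentropy",
    metrics=["acc"]
)
tl_model.fit(train_generator, epochs=3, validation_data=val_generator)

Soon, I’ll dive into how transfer learning models can be improved to get even better results. Stay tuned!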
