In this tutorial, you will learn how to classify images of cats and dogs by using transfer learning from a pre-trained network.
A pre-trained model is a saved network that was previously trained on a large dataset, typically on a large-scale image-classification task. You can either use the pre-trained model as is or use transfer learning to customize it for a given task.
The intuition behind transfer learning for image classification is that if a model is trained on a large and general enough dataset, this model will effectively serve as a generic model of the visual world. You can then take advantage of these learned feature maps without having to start from scratch by training a large model on a large dataset.
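As a minimal sketch of the idea, the snippet below freezes a base network and trains only a small new head on top of it. The base here is a tiny untrained stand-in (named base_model purely for illustration) so that the example runs without downloading any weights; in a real setting it would be a network with weights loaded from a large dataset. The layer sizes and the sigmoid head are assumptions, not part of the text above.

```python
import tensorflow as tf

# Stand-in for a pre-trained network (assumption for illustration only).
# A real base would carry weights learned on a large, general dataset.
base_model = tf.keras.Sequential([
    tf.keras.Input(shape=(160, 160, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(16, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
], name="base_model")

# Freeze the base so its feature maps are reused rather than retrained.
base_model.trainable = False

# Add a new, trainable classification head for the two-class task.
inputs = tf.keras.Input(shape=(160, 160, 3))
features = base_model(inputs, training=False)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(features)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer="adam",
              loss="binary_crossentropy",
              metrics=["accuracy"])

# Only the new head's kernel and bias remain trainable.
print(len(model.trainable_variables))
```

Because the base is frozen, training this model updates only the head, which is why far less data and compute are needed than when training the whole network from scratch.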
In this notebook, you will try two ways to customize a pretrained model:
You will follow the general machine learning workflow.
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
In this tutorial, you will use a dataset containing several thousand images of cats and dogs. Download and extract a zip file containing the images, then create a tf.data.Dataset for training and validation using the tf.keras.utils.image_dataset_from_directory utility. You can learn more about loading images in this tutorial.
_URL = ''
path_to_zip = tf.keras.utils.get_file('', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
IMG_SIZE = (160, 160)
train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            image_size=IMG_SIZE)
Downloading data from 68606236/68606236 [==============================] – 0s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.
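To see how the utility maps a directory layout to labels, here is a small self-contained sketch that builds a throwaway directory with one subfolder per class and loads it. The folder names, image contents, and batch size are made up for illustration; the only assumption carried over from the text is that each class lives in its own subdirectory.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Build a tiny temporary directory tree: one subfolder per class.
# The images are random noise and exist only to make the example runnable.
root = tempfile.mkdtemp()
rng = np.random.default_rng(0)
for class_name in ("cats", "dogs"):
    class_dir = os.path.join(root, class_name)
    os.makedirs(class_dir)
    for i in range(4):
        pixels = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
        png_bytes = tf.io.encode_png(pixels)
        tf.io.write_file(os.path.join(class_dir, f"img_{i}.png"), png_bytes)

# Class names are inferred from the subfolder names; images are resized
# to image_size and grouped into batches.
dataset = tf.keras.utils.image_dataset_from_directory(
    root,
    image_size=(160, 160),
    batch_size=4,
    shuffle=True,
    seed=0,
)

print(dataset.class_names)
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)
```

Each element of the resulting dataset is a batch: a tensor of resized images together with a tensor of integer labels that index into class_names.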
Show the first nine images and labels from the training set:
class_names = train_dataset.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
