Keras Implementation of ResNet-50 (Residual Networks) Architecture from Scratch

Introduction

In this article, we will walk through a Keras implementation of the ResNet-50 architecture from scratch. ResNet-50 (Residual Network) is a deep neural network used as a backbone for many computer vision applications such as object detection and image segmentation. ResNet was created by four researchers, Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, and it won the ImageNet challenge in 2015 with a top-5 error rate of 3.57%. It also addressed the vanishing gradient problem that was common in very deep neural networks.

In our Keras implementation of ResNet-50 architecture, we will use the famous Dogs Vs Cats dataset from Kaggle. Moreover, we will use Google Colab to leverage its free GPU for fast training.

Architecture of ResNet

In recent years of the deep learning revolution, neural networks have become deeper, with state-of-the-art networks going from a handful of layers (e.g., the 16 layers of VGG16) to over a hundred layers. The main benefit of a very deep network is that it can represent very complex functions. It can also learn features at many different levels of abstraction, for example, from edges (at the lower layers) to very complex features (at the deeper layers) in the case of an image.

Neural Network Layers in ImageNet Challenge (Source)

However, using a deeper network doesn’t always produce favorable outcomes. A major barrier to training very deep neural networks is the phenomenon of vanishing gradients. Very deep networks often have a gradient signal that goes to zero quickly, which makes gradient descent slow. More specifically, as you backpropagate from the final layer back to the first layer, you multiply by the weight matrix at each step, so the gradient can shrink exponentially toward zero and hinder the training process.
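
To make the exponential decay concrete, here is a minimal sketch in plain Python. It assumes a toy network in which every layer scales the backpropagated gradient by the same factor `w` (a hypothetical stand-in for the weight matrix). Real networks multiply by full matrices and activation derivatives, but the effect is similar.

```python
def backprop_gradient_magnitude(num_layers, w, grad_at_output=1.0):
    """Gradient magnitude reaching the first layer of a toy chain network.

    Each of the `num_layers` layers multiplies the incoming gradient by `w`,
    mimicking the repeated weight multiplication during backpropagation.
    """
    grad = grad_at_output
    for _ in range(num_layers):
        grad *= w
    return abs(grad)


# The deeper the toy network, the smaller the gradient at the first layer.
for depth in (5, 20, 50, 100):
    print(depth, backprop_gradient_magnitude(depth, w=0.5))
```

With `w=0.5` the gradient halves at every layer, so after 50 layers it is already far too small to drive a useful weight update.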

Skip Connection

In ResNet architecture, a “shortcut” or a “skip connection” allows the gradient to be directly backpropagated to earlier layers:

Skip Connections in Resnet Architecture (Source)

The image on the top shows the “main path” through the network. The image on the bottom adds a shortcut to the main path. By stacking these ResNet blocks on top of each other, you can form a very deep network.
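
The effect of the shortcut on the gradient can be sketched with a toy calculation. The example below assumes each block's main path has the same small local derivative `local_grad` (a hypothetical value chosen for illustration). A plain block computes y = F(x), while a residual block computes y = F(x) + x, which adds 1 to the block's derivative.

```python
def gradient_through_blocks(num_blocks, local_grad, use_skip):
    """Gradient magnitude after backpropagating through a stack of toy blocks.

    Without a skip connection a block computes y = F(x), so its derivative is
    `local_grad`. With a skip connection it computes y = F(x) + x, so the
    derivative becomes `local_grad + 1`: the "+ 1" is the shortcut path.
    """
    per_block = local_grad + 1.0 if use_skip else local_grad
    grad = 1.0
    for _ in range(num_blocks):
        grad *= per_block
    return abs(grad)


plain = gradient_through_blocks(50, local_grad=0.01, use_skip=False)
residual = gradient_through_blocks(50, local_grad=0.01, use_skip=True)
print("plain stack:   ", plain)
print("residual stack:", residual)
```

In this toy setting the plain stack's gradient collapses toward zero, while the residual stack's gradient stays close to 1 because the shortcut always passes the gradient through.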

Two main types of blocks are used in a ResNet, depending on whether the input and output dimensions are the same or different.

1. Identity Block

The identity block is the standard block used in ResNets and corresponds to the case where the input activation has the same dimension as the output activation.

ResNet Residual Network Keras Implementation Identity Block

2. Convolutional Block

We can use this type of block when the input and output dimensions don’t match up. The difference with the identity block is that there is a CONV2D layer in the shortcut path.
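
The sketch below shows why the shortcut needs its own convolution when dimensions change. It uses the standard Keras output-size rules for `'valid'` and `'same'` padding and mirrors the layer settings of the convolutional block implemented later in this article. The helper names are illustrative and not part of Keras.

```python
def conv_output_size(n, kernel, stride, padding):
    """Output length of a Keras Conv2D along one spatial dimension."""
    if padding == "same":
        return -(-n // stride)           # ceil(n / stride)
    return (n - kernel) // stride + 1    # 'valid': no padding


def block_output_shapes(n, f, filters, s):
    """Return (size, channels) for the main path and for the shortcut path."""
    F1, F2, F3 = filters
    size = conv_output_size(n, 1, s, "valid")     # 1x1 conv, stride s
    size = conv_output_size(size, f, 1, "same")   # f x f conv, stride 1
    size = conv_output_size(size, 1, 1, "valid")  # 1x1 conv, stride 1
    main = (size, F3)
    # The shortcut applies a single 1x1 conv with stride s and F3 filters.
    shortcut = (conv_output_size(n, 1, s, "valid"), F3)
    return main, shortcut


# A 55x55 input with 256 channels entering a block with stride 2
main, shortcut = block_output_shapes(n=55, f=3, filters=[128, 128, 512], s=2)
print(main, shortcut)  # (28, 512) (28, 512)
```

Without the convolution in the shortcut, the shortcut tensor would still be 55 x 55 x 256 and could not be added to the 28 x 28 x 512 output of the main path.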

ResNet Residual Network Keras Implementation Convolutional Block

ResNet-50

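The ResNet-50 model consists of 5 stages. Stage 1 is a single 7x7 convolution followed by max pooling. Stages 2 to 5 each start with one convolutional block followed by several identity blocks (2, 3, 5, and 2 identity blocks respectively). Each convolutional block has 3 convolution layers on its main path, and each identity block also has 3 convolution layers. ResNet-50 has over 23 million trainable parameters.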
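
The "50" in the name can be checked with a little arithmetic. The sketch below counts the first convolution, the three main-path convolutions of every block, and the single fully connected layer of the original architecture. By convention the convolutions on the shortcut paths are not counted. The variable names are illustrative.

```python
# Blocks per stage (one convolutional block plus the identity blocks)
blocks_per_stage = {2: 3, 3: 4, 4: 6, 5: 3}

CONVS_PER_BLOCK = 3   # 1x1, 3x3, 1x1 on the main path of every block
STEM_CONVS = 1        # the 7x7 convolution in stage 1
FINAL_DENSE = 1       # the fully connected layer of the original ResNet-50

total_layers = (
    STEM_CONVS
    + CONVS_PER_BLOCK * sum(blocks_per_stage.values())
    + FINAL_DENSE
)
print(total_layers)  # 50
```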

ResNet Keras Implementation Architecture

Dogs Vs Cats Kaggle Dataset

Dogs vs Cats Dataset

Dogs Vs Cats is a famous dataset from Kaggle that contains 25,000 images of dogs and cats for training a classification model. This dataset is usually used for introductory lessons on convolutional neural networks.

Setting up Google Colab

Google Colab is an online managed Jupyter Notebook environment where you can train deep learning models on a GPU. The free plan of Google Colab allows you to train a model for up to 12 hours before the runtime disconnects. Select GPU as the runtime type before launching the notebook, as shown below:
Google Colab GPU Runtime

We have uploaded the dataset to our Google Drive, but before we can use it in Colab we have to mount the Drive directory in our runtime environment, as shown below. The command prints a URL: open it, authenticate with your Google account, copy the authorization code into the prompt, and press Enter.

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive

ResNet-50 Keras Implementation

Importing libraries and setting up GPU

In [2]:

import cv2
import numpy as np
import os
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import keras
from keras.models import Sequential, Model,load_model
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping,ModelCheckpoint
from google.colab.patches import cv2_imshow
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D,MaxPool2D
from keras.preprocessing import image
from keras.initializers import glorot_uniform
Using TensorFlow backend.
The code below lists the GPUs available to the TensorFlow backend:
In [3]:
K.tensorflow_backend._get_available_gpus()
['/job:localhost/replica:0/task:0/device:GPU:0']

Define training and testing path of dataset

The images of each class must be placed in separate subfolders under the train and test folders, as shown below:
Required Folder Structure
In [4]:
train_path="/content/gdrive/My Drive/datasets/train"
test_path="/content/gdrive/My Drive/datasets/test"
class_names=os.listdir(train_path)
class_names_test=os.listdir(test_path)
Let us verify the class names of folders by printing them.
In [5]:
print(class_names)
print(class_names_test)
['cat', 'dog']
['Cat', 'Dog']
Let us read a couple of sample images from the dataset to see what kind of images we have. cv2.imshow is disabled in Colab because it causes Jupyter sessions to crash, so we use cv2_imshow() as a substitute.
In [6]:
#Sample datasets images
image_dog=cv2.imread("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")
cv2_imshow(image_dog)
image_cat=cv2.imread("/content/gdrive/My Drive/datasets/test/Cat/5.jpg")
cv2_imshow(image_cat)

Dog Vs Cat Dataset - Example 1

Dog Vs Cat Dataset - Example 2

Preparation of datasets

Loading an entire dataset at once can fail when the machine does not have enough memory to hold it.

Keras provides the ImageDataGenerator class that defines the configuration for image data preparation and augmentation. The generator progressively loads the images in your dataset, allowing you to work with very large datasets containing thousands or millions of images that may not fit into system memory.

In [7]:
train_datagen = ImageDataGenerator(
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
)
test_datagen = ImageDataGenerator()
In [8]:
train_generator = train_datagen.flow_from_directory(
    train_path,
    target_size=(224, 224),
    batch_size=32,
    shuffle=True,
    class_mode='binary',
)
test_generator = test_datagen.flow_from_directory(
    test_path,
    target_size=(224, 224),
    batch_size=32,
    shuffle=False,
    class_mode='binary',
)
Out [8]:
Found 25000 images belonging to 2 classes.
Found 627 images belonging to 2 classes.
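
Two details of the generators are worth spelling out. `flow_from_directory` assigns class indices by sorting the subfolder names, and the number of batches per epoch is the image count divided by the batch size, rounded up. The plain-Python sketch below reproduces both rules. The helper functions are illustrative stand-ins, not Keras APIs.

```python
import math


def class_indices(subfolders):
    """Map class folder names to integer labels in sorted order."""
    return {name: index for index, name in enumerate(sorted(subfolders))}


def batches_per_epoch(num_images, batch_size):
    """Number of batches needed to see every image once."""
    return math.ceil(num_images / batch_size)


# Folder names as printed earlier for the train and test directories
print(class_indices(['cat', 'dog']))   # {'cat': 0, 'dog': 1}
print(class_indices(['Cat', 'Dog']))   # {'Cat': 0, 'Dog': 1}

print(batches_per_epoch(25000, 32))    # 782
print(batches_per_epoch(627, 32))      # 20
```

Although the folder names differ in capitalization between train and test, cats map to 0 and dogs to 1 in both, so the labels stay consistent.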

Implementation of Identity Block

Let us implement the identity block in Keras –

In [9]:
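def identity_block(X, f, filters, stage, block):
    """Identity block: the input is added to the main path output unchanged.

    X is the input tensor, f is the kernel size of the middle convolution,
    filters holds the output channels of the three convolutions, and
    stage/block are used only to build unique layer names.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters

    # Save the input so it can be added back after the main path
    X_shortcut = X

    # First component of the main path: 1x1 convolution
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    # Second component: f x f convolution, 'same' padding keeps the spatial size
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    # Third component: 1x1 convolution, no activation before the addition
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    # Skip connection: add the saved input, then apply ReLU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X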

Implementation of Convolutional Block

Let us implement the convolutional block in Keras –

In [10]:
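def convolutional_block(X, f, filters, stage, block, s=2):
    """Convolutional block: the shortcut passes through its own 1x1 convolution.

    The arguments match identity_block, plus s, the stride used by the first
    main-path convolution and by the shortcut convolution.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    F1, F2, F3 = filters

    # Save the input for the shortcut path
    X_shortcut = X

    # First component of the main path: 1x1 convolution with stride s
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    # Second component: f x f convolution, 'same' padding keeps the spatial size
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    # Third component: 1x1 convolution, no activation before the addition
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    # Shortcut path: 1x1 convolution with stride s so its shape matches the main path
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    # Skip connection: add the resized shortcut, then apply ReLU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X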
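
As a sanity check on the block, the parameter counts of its layers can be computed by hand and compared with the model summary shown later in this article. The sketch below assumes the usual formulas: a Conv2D layer has kernel x kernel x input channels x output channels weights plus one bias per output channel, and a BatchNormalization layer has four values per channel. The helper names are illustrative.

```python
def conv2d_params(kernel, in_channels, out_channels):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * out_channels + out_channels


def batchnorm_params(channels):
    """Gamma, beta, moving mean and moving variance for every channel."""
    return 4 * channels


def convolutional_block_params(in_channels, f, filters):
    """Parameter count of each layer in a convolutional block, by name suffix."""
    F1, F2, F3 = filters
    return {
        "branch2a": conv2d_params(1, in_channels, F1),
        "branch2b": conv2d_params(f, F1, F2),
        "branch2c": conv2d_params(1, F2, F3),
        "branch1": conv2d_params(1, in_channels, F3),  # shortcut convolution
        "bn_branch2a": batchnorm_params(F1),
        "bn_branch2b": batchnorm_params(F2),
        "bn_branch2c": batchnorm_params(F3),
        "bn_branch1": batchnorm_params(F3),
    }


# Stage 2 convolutional block: 64 input channels, filters [64, 64, 256]
for name, count in convolutional_block_params(64, 3, [64, 64, 256]).items():
    print(name, count)
```

The convolution counts (4160, 36928, 16640, and 16640) match the `res2a_branch2a`, `res2a_branch2b`, `res2a_branch2c`, and `res2a_branch1` rows of the summary.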

Implementation of ResNet-50

In this Keras implementation of ResNet-50, we have not defined the fully connected layer in the network. The transfer learning section below explains why.

In [11]:
def ResNet50(input_shape=(224, 224, 3)):
    """Build the ResNet-50 convolutional base (without the fully connected top)."""

    X_input = Input(input_shape)

    # Zero-pad the input by 3 pixels on every side
    X = ZeroPadding2D((3, 3))(X_input)

    # Stage 1: 7x7 convolution, batch normalization, ReLU and max pooling
    X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)

    # Stage 2: one convolutional block (stride 1) and two identity blocks
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='b')
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='c')

    # Stage 3: one convolutional block and three identity blocks
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='b')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='c')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='d')

    # Stage 4: one convolutional block and five identity blocks
    X = convolutional_block(X, f=3, filters=[256, 256, 1024], stage=4, block='a', s=2)
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='b')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='c')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='d')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='e')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='f')

    # Stage 5: one convolutional block and two identity blocks
    X = convolutional_block(X, f=3, filters=[512, 512, 2048], stage=5, block='a', s=2)
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='b')
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='c')

    # Average pooling over 2x2 windows of the final feature map
    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)

    model = Model(inputs=X_input, outputs=X, name='ResNet50')

    return model
In [12]:
base_model = ResNet50(input_shape=(224, 224, 3))
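
To see how the spatial size of a 224 x 224 input shrinks through the network above, the sketch below applies the standard Keras output-size rules to each size-changing layer. It assumes the layer settings used in the `ResNet50` function and that `AveragePooling2D` uses its default stride, which equals the pool size. Identity blocks are skipped because they keep the size unchanged. The helper names are illustrative.

```python
def output_size(n, kernel, stride, padding):
    """Output length of a Keras conv or pooling layer along one dimension."""
    if padding == "same":
        return -(-n // stride)           # ceil(n / stride)
    return (n - kernel) // stride + 1    # 'valid': no padding


def trace_spatial_size(n=224):
    """Spatial size after each size-changing step of the network."""
    sizes = {}
    n = n + 2 * 3                                   # ZeroPadding2D((3, 3))
    sizes["zero_padding"] = n
    n = output_size(n, 7, 2, "valid")               # conv1: 7x7, stride 2
    sizes["conv1"] = n
    n = output_size(n, 3, 2, "valid")               # max pooling: 3x3, stride 2
    sizes["max_pool"] = n
    for stage, s in ((2, 1), (3, 2), (4, 2), (5, 2)):
        n = output_size(n, 1, s, "valid")           # first 1x1 conv of the stage
        sizes["stage%d" % stage] = n
    n = output_size(n, 2, 2, "same")                # average pooling: 2x2, 'same'
    sizes["avg_pool"] = n
    return sizes


print(trace_spatial_size(224))
```

The values up to stage 4 (230, 112, 55, 55, 28, 14) agree with the output shapes in the model summary shown later.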

Quick Concept about Transfer Learning

Deep convolutional neural networks can take days to train and require a lot of computational resources. To overcome this, we use transfer learning in this Keras implementation of ResNet-50.

Transfer learning is a technique whereby a deep neural network model is first trained on a problem similar to the problem that is being solved. One or more layers from the trained model are then used in a new model trained on the problem of interest. In simple words, transfer learning refers to a process where a model trained on one problem is used in some way on a second related problem.

Here we define the fully connected layers manually so that the network outputs the classes we need while still taking advantage of the pre-trained model.

We reuse weights from a model that was pre-trained on a standard computer vision benchmark dataset, ImageNet. We have downloaded pre-trained weights that do not include the top (fully connected) layers, because we replace the original top with our own three dense layers, for which no pre-trained weights exist.

In [13]:
headModel = base_model.output
headModel = Flatten()(headModel)
headModel=Dense(256, activation='relu', name='fc1',kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel=Dense(128, activation='relu', name='fc2',kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel = Dense( 1,activation='sigmoid', name='fc3',kernel_initializer=glorot_uniform(seed=0))(headModel)
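
The size of this new head can be estimated with a short calculation. The sketch below assumes that the base model outputs a 4 x 4 x 2048 feature map for a 224 x 224 input (7 x 7 after stage 5, halved by the 2x2 average pooling with 'same' padding), and uses the usual Dense formula of inputs x units weights plus one bias per unit. The helper name is illustrative.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


# Assumed base model output: 4 x 4 spatial positions with 2048 channels
flattened = 4 * 4 * 2048

head_params = {
    "fc1": dense_params(flattened, 256),
    "fc2": dense_params(256, 128),
    "fc3": dense_params(128, 1),
}
print(head_params)
print(sum(head_params.values()))
```

Under this assumption almost all of the head's parameters sit in `fc1`, because it connects every one of the 32768 flattened features to 256 units.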

Finally, let us create the model that takes the base model's input and produces the final layer of the head model as its output:

In [14]:
model = Model(inputs=base_model.input, outputs=headModel)

Here is the model summary

In [15]:
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 230, 230, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        zero_padding2d_1[0][0]           
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 112, 112, 64) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 55, 55, 64)   0           activation_1[0][0]               
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 55, 55, 64)   4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_2[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_3[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 55, 55, 256)  16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, 55, 55, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 55, 55, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_4[0][0]               
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 55, 55, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_5[0][0]               
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 55, 55, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_6[0][0]               
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, 55, 55, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_4[0][0]               
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 55, 55, 256)  0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_7[0][0]               
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 55, 55, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 55, 55, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_9[0][0]               
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, 55, 55, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_7[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 55, 55, 256)  0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_11[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_12[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 28, 28, 512)  131584      activation_10[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512)  2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, 28, 28, 512)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 28, 28, 512)  0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_13[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_14[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_15[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 28, 28, 512)  0           bn3b_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 28, 28, 512)  0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_16[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_18[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, 28, 28, 512)  0           bn3c_branch2c[0][0]              
                                                                 activation_16[0][0]              
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 28, 28, 512)  0           add_6[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 28, 28, 512)  0           add_7[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_23[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_24[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 14, 14, 1024) 525312      activation_22[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, 14, 14, 1024) 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 14, 14, 1024) 0           add_8[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_25[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_26[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_27[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 14, 14, 1024) 0           bn4b_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, 14, 14, 1024) 0           add_9[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_28[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_29[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_30[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, 14, 14, 1024) 0           bn4c_branch2c[0][0]              
                                                                 activation_28[0][0]              
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 14, 14, 1024) 0           add_10[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_31[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_32[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_33[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
                                                                 activation_31[0][0]              
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 14, 14, 1024) 0           add_11[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_34[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_35[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_36[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 14, 14, 1024) 0           bn4e_branch2c[0][0]              
                                                                 activation_34[0][0]              
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 14, 14, 1024) 0           add_12[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_37[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_38[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_39[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_37[0][0]              
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 14, 14, 1024) 0           add_13[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 7, 7, 512)    524800      activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_41[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_42[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 7, 7, 2048)   2099200     activation_40[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_14 (Add)                    (None, 7, 7, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 7, 7, 2048)   0           add_14[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_43[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_44[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_45[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 7, 7, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_43[0][0]              
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 7, 7, 2048)   0           add_15[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_46[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_47 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_47[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_48[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_46[0][0]              
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]                     
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 2048)   0           activation_49[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 32768)        0           average_pooling2d_1[0][0]        
__________________________________________________________________________________________________
fc1 (Dense)                     (None, 256)          8388864     flatten_1[0][0]                  
__________________________________________________________________________________________________
fc2 (Dense)                     (None, 128)          32896       fc1[0][0]                        
__________________________________________________________________________________________________
fc3 (Dense)                     (None, 1)            129         fc2[0][0]                        
==================================================================================================
Total params: 32,009,601
Trainable params: 31,956,481
Non-trainable params: 53,120
__________________________________________________________________________________________________
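
The parameter counts in the summary above can be reproduced by hand. The following is a minimal sketch, not part of the original notebook: the helper names are hypothetical, and the kernel sizes (1x1 for the branch2a, branch2c and branch1 convolutions, 3x3 for branch2b) are assumed from the standard ResNet bottleneck design, since the summary itself does not print them.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels):
    """Conv2D with bias: one weight per kernel cell per channel pair, plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * out_channels + out_channels


def batchnorm_params(channels):
    """BatchNormalization: gamma, beta, moving mean and moving variance per channel."""
    return 4 * channels


def dense_params(in_features, out_features):
    """Dense with bias: full weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features


# Kernel sizes below are assumptions (standard bottleneck block), not read from the summary.
print(conv2d_params(1, 1, 128, 512))    # res3d_branch2c -> 66048
print(conv2d_params(3, 3, 256, 256))    # res4a_branch2b -> 590080
print(conv2d_params(1, 1, 1024, 2048))  # res5a_branch1  -> 2099200
print(batchnorm_params(1024))           # bn4a_branch2c  -> 4096

# flatten_1 turns the (4, 4, 2048) pooled feature map into a vector of 4 * 4 * 2048 values.
flatten_size = 4 * 4 * 2048
print(flatten_size)                     # 32768
print(dense_params(flatten_size, 256))  # fc1 -> 8388864
print(dense_params(256, 128))           # fc2 -> 32896
print(dense_params(128, 1))             # fc3 -> 129
```

Note that fc1 alone accounts for more than a quarter of the 32,009,601 parameters, because it is fully connected to the 32,768-element flattened feature map.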

Next, we load the pre-trained weights into the base model:

(Click here to download the pre-trained weights.)

In [16]:
base_model.load_weights("/content/gdrive/My Drive/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5")

The initial layers of a network learn very general features, and as we go higher up the network, the layers tend to learn patterns more specific to the task the network is being trained on. We exploit this property by keeping the initial layers intact (freezing them) and retraining only the later layers for our task. This is known as fine-tuning a network.

The advantage of fine-tuning is that we do not have to train the entire network from scratch, so much less training data is required. Fewer parameters need to be updated as well, so training also takes less time.

In Keras, each layer has an attribute called “trainable”. To freeze the weights of a particular layer, we set this attribute to False, indicating that the layer should not be trained. We can then go over the layers and select which ones we want to train.
In our case, we are freezing all the convolutional blocks of the model.

In [17]:
for layer in base_model.layers:
    layer.trainable = False

Let us print all the layers in our ResNet-50 model along with their trainable flag. As you can see, every layer up to the final average pooling layer is False, which means that the parameters of these layers will not be updated during training. The remaining layers (Flatten and the three Dense layers) have trainable set to True, so the parameters of the Dense layers get updated during training.

In [18]:
for layer in model.layers:
    print(layer, layer.trainable)
<keras.engine.input_layer.InputLayer object at 0x7f5bdd4602e8> False
<keras.layers.convolutional.ZeroPadding2D object at 0x7f5bd41ecfd0> False
<keras.layers.convolutional.Conv2D object at 0x7f5bc042f048> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c7f5390> False
<keras.layers.core.Activation object at 0x7f5b6c7f5e80> False
<keras.layers.pooling.MaxPooling2D object at 0x7f5bc043db00> False
<keras.layers.convolutional.Conv2D object at 0x7f5c349fc198> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cf9ee80> False
<keras.layers.core.Activation object at 0x7f5b6cf9ee48> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cf47a90> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cf4f080> False
<keras.layers.core.Activation object at 0x7f5b6cf5f518> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cf08a90> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cf3b7f0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cf0f160> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cecb8d0> False
<keras.layers.merge.Add object at 0x7f5b6ced15f8> False
<keras.layers.core.Activation object at 0x7f5b6cee9860> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ce83898> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ce8ca90> False
<keras.layers.core.Activation object at 0x7f5b6ce8ccf8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ce3ef60> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3efd0> False
<keras.layers.core.Activation object at 0x7f5b6ce4c6d8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ce78e80> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cdfe3c8> False
<keras.layers.merge.Add object at 0x7f5b6ce06b38> False
<keras.layers.core.Activation object at 0x7f5b6ce1d898> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ce3b898> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3b9b0> False
<keras.layers.core.Activation object at 0x7f5b6cdc0b70> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cdf2f98> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cdf2d30> False
<keras.layers.core.Activation object at 0x7f5b6cd826a0> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cdade48> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cdb6390> False
<keras.layers.merge.Add object at 0x7f5b6cdbdb00> False
<keras.layers.core.Activation object at 0x7f5b6cd52860> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cd6f898> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cd78a90> False
<keras.layers.core.Activation object at 0x7f5b6cd78c88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cd2bf98> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cd2bd30> False
<keras.layers.core.Activation object at 0x7f5b6cd3b6a0> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cce3e80> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cc989b0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cced3c8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cca4a90> False
<keras.layers.merge.Add object at 0x7f5b6ccad7b8> False
<keras.layers.core.Activation object at 0x7f5b6cc45a20> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cc60a58> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cc67c50> False
<keras.layers.core.Activation object at 0x7f5b6cc67eb8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cc1a9b0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cc20080> False
<keras.layers.core.Activation object at 0x7f5b6cc28898> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cbd4ba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cbdb588> False
<keras.layers.merge.Add object at 0x7f5b6cbe3cf8> False
<keras.layers.core.Activation object at 0x7f5b6cbf8a58> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cb96a90> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cb9ec88> False
<keras.layers.core.Activation object at 0x7f5b6cb9ee48> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cb509b0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cb59080> False
<keras.layers.core.Activation object at 0x7f5b6cb5e898> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cb0dba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cb115c0> False
<keras.layers.merge.Add object at 0x7f5b6cb1cd30> False
<keras.layers.core.Activation object at 0x7f5b6cb2ea58> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6cacda90> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6cad5a58> False
<keras.layers.core.Activation object at 0x7f5b6cad5ac8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ca869b0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ca90080> False
<keras.layers.core.Activation object at 0x7f5b6ca98898> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ca41ba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ca4a5c0> False
<keras.layers.merge.Add object at 0x7f5b6ca52d30> False
<keras.layers.core.Activation object at 0x7f5b6ca66a58> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ca02a90> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ca09c88> False
<keras.layers.core.Activation object at 0x7f5b6ca09e80> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6ca3d9b0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c9c4080> False
<keras.layers.core.Activation object at 0x7f5b6c9cc898> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c9f8ba8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c9acb70> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c97e5c0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c9bac50> False
<keras.layers.merge.Add object at 0x7f5b6c941978> False
<keras.layers.core.Activation object at 0x7f5b6c95bc50> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c971b70> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c979e48> False
<keras.layers.core.Activation object at 0x7f5b6c8ff5f8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c92edd8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c934320> False
<keras.layers.core.Activation object at 0x7f5b6c93ca90> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c8e8c88> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c8ee7b8> False
<keras.layers.merge.Add object at 0x7f5b6c8f7f28> False
<keras.layers.core.Activation object at 0x7f5b6c894c88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c8a9ba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c8afe80> False
<keras.layers.core.Activation object at 0x7f5b6c8bb5f8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c865dd8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c86d320> False
<keras.layers.core.Activation object at 0x7f5b6c873a90> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c822c88> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c8257b8> False
<keras.layers.merge.Add object at 0x7f5b6c82def0> False
<keras.layers.core.Activation object at 0x7f5b6c789c88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c7a0b70> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c7a8e48> False
<keras.layers.core.Activation object at 0x7f5b6c7af5c0> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c758da0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c7602e8> False
<keras.layers.core.Activation object at 0x7f5b6c768a58> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c714c88> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c71b7b8> False
<keras.layers.merge.Add object at 0x7f5b6c722f28> False
<keras.layers.core.Activation object at 0x7f5b6c6bfc88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c6d4ba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c6dae80> False
<keras.layers.core.Activation object at 0x7f5b6c6e35f8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c68cdd8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c696320> False
<keras.layers.core.Activation object at 0x7f5b6c69fa58> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c649c88> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c6507b8> False
<keras.layers.merge.Add object at 0x7f5b6c658f28> False
<keras.layers.core.Activation object at 0x7f5b6c675c88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c60cba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c613e80> False
<keras.layers.core.Activation object at 0x7f5b6c61c5f8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c5c8dd8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c5cf320> False
<keras.layers.core.Activation object at 0x7f5b6c5d7a90> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c57fc88> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c5897b8> False
<keras.layers.merge.Add object at 0x7f5b6c58ff28> False
<keras.layers.core.Activation object at 0x7f5b6c5adc88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c542ba8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c547e80> False
<keras.layers.core.Activation object at 0x7f5b6c5535f8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c57cdd8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c505320> False
<keras.layers.core.Activation object at 0x7f5b6c50ca90> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c537c88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c4e8da0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c4bd7b8> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c4f8c88> False
<keras.layers.merge.Add object at 0x7f5b6c47fba8> False
<keras.layers.core.Activation object at 0x7f5b6c49ee80> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c4b3d68> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c4b8cf8> False
<keras.layers.core.Activation object at 0x7f5b6c4427f0> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c46e7f0> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c474518> False
<keras.layers.core.Activation object at 0x7f5b6c3fec88> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c426e80> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c42a9b0> False
<keras.layers.merge.Add object at 0x7f5b6c42aac8> False
<keras.layers.core.Activation object at 0x7f5b6c3d2eb8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c3e8d68> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c3f1cf8> False
<keras.layers.core.Activation object at 0x7f5b6c3f97b8> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c3a6f98> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c3aa4e0> False
<keras.layers.core.Activation object at 0x7f5b6c3b6c50> False
<keras.layers.convolutional.Conv2D object at 0x7f5b6c35de48> False
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c365978> False
<keras.layers.merge.Add object at 0x7f5b6c365a90> False
<keras.layers.core.Activation object at 0x7f5b6c30aeb8> False
<keras.layers.pooling.AveragePooling2D object at 0x7f5b6c311fd0> False
<keras.layers.core.Flatten object at 0x7f5b6c2c22b0> True
<keras.layers.core.Dense object at 0x7f5b6c2c2c18> True
<keras.layers.core.Dense object at 0x7f5b6c2c7b00> True
<keras.layers.core.Dense object at 0x7f5b6c2d6780> True

We then compile the model using the compile function. Here we pass it three arguments: the optimizer, the loss function, and the performance metrics. The optimizer is stochastic gradient descent (SGD) with a learning rate of 1e-3 and a momentum of 0.9. We use the binary_crossentropy loss function since we are doing binary classification.

In [19]:

opt = SGD(lr=1e-3, momentum=0.9)
model.compile(loss="binary_crossentropy", optimizer=opt,metrics=["accuracy"])
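
To make the loss concrete, here is a minimal pure-Python sketch of binary cross-entropy, the quantity that the binary_crossentropy setting asks the optimizer to minimize. The function name and the clipping constant eps are illustrative assumptions; this is not Keras's own code.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over true labels (0 or 1) and predicted probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        # Clip predictions away from 0 and 1 so that log() stays finite.
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(y_true)

# Confident, correct predictions give a small loss ...
low = binary_crossentropy([1, 0], [0.9, 0.1])
# ... while confident, wrong predictions are penalised heavily.
high = binary_crossentropy([1, 0], [0.1, 0.9])
print(low, high)
```

The loss grows quickly as a prediction moves toward the wrong class, which is why it suits a two-class problem like Dogs vs Cats.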

Early Stopping

A basic problem that arises in training a neural network is deciding for how many epochs the model should be trained. Too many epochs may lead to overfitting, and too few may lead to underfitting.

To overcome this problem, the concept of early stopping is used.

In this technique, we specify an arbitrarily large number of training epochs and stop training once the model's performance stops improving on a hold-out validation dataset. Keras supports early stopping via a callback called EarlyStopping.

Below are various arguments in EarlyStopping.

  • monitor – This allows us to specify the performance measure to monitor in order to end training.
  • mode – It specifies whether the monitored metric should be maximized ('max') or minimized ('min').
  • verbose – To discover the training epoch on which training was stopped, the “verbose” argument can be set to 1. Once stopped, the callback will print the epoch number.
  • patience – The first sign of no further improvement may not be the best time to stop training. This is because the model may coast into a plateau of no improvement or even get slightly worse before getting much better. We can account for this by adding a delay to the trigger in terms of the number of epochs on which we would like to see no improvement. This can be done by setting the “patience” argument.
In [19]:
es=EarlyStopping(monitor='val_accuracy', mode='max', verbose=1, patience=20)
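
The effect of the patience argument can be illustrated with a small pure-Python simulation. This is a simplified sketch of the idea, not the Keras implementation; the function name stopping_epoch and the accuracy values are hypothetical.

```python
def stopping_epoch(scores, patience, mode="max"):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best = None
    wait = 0  # number of epochs since the last improvement
    for epoch, score in enumerate(scores, start=1):
        improved = (
            best is None
            or (mode == "max" and score > best)
            or (mode == "min" and score < best)
        )
        if improved:
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation accuracies, one value per epoch.
val_accuracy = [0.90, 0.92, 0.91, 0.92, 0.91]
print(stopping_epoch(val_accuracy, patience=2))   # → 4
print(stopping_epoch(val_accuracy, patience=20))  # → None
```

With a small patience the run ends as soon as the metric stalls for two epochs, while a large patience such as 20 gives the model room to recover from a plateau.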

Model Checkpoint

The EarlyStopping callback will stop training once triggered, but the model at the end of training may not be the model with the best performance on the validation dataset.

An additional callback is required that will save the best model observed during training for later use. This is known as the ModelCheckpoint callback.

The ModelCheckpoint callback is flexible in the way it can be used, but in this case, we will use it only to save the best model observed during training as defined by a chosen performance measure on the validation dataset.

In [20]:
mc = ModelCheckpoint('/content/gdrive/My Drive/best_model.h5', monitor='val_accuracy', mode='
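
The idea of keeping only the best model can be sketched in pure Python as follows. It assumes the checkpoint is written only when the monitored metric strictly improves, which is how the save_best_only option of ModelCheckpoint behaves; whether the call above sets that option is not visible, and the helper name epochs_saved is hypothetical.

```python
def epochs_saved(scores, mode="max"):
    """Return the 1-based epochs at which a best-only checkpoint would be written."""
    best = None
    saved = []
    for epoch, score in enumerate(scores, start=1):
        better = best is None or (score > best if mode == "max" else score < best)
        if better:
            best = score
            saved.append(epoch)  # a real callback would write the model file here
    return saved

# val_accuracy for the first five epochs of the training log in this article
val_accuracy = [0.9394, 0.9601, 0.9601, 0.9617, 0.9569]
print(epochs_saved(val_accuracy))  # → [1, 2, 4]
```

Epochs 3 and 5 do not beat the best value seen so far, so under this assumption the file on disk would still hold the epoch 4 model after epoch 5.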

Training the Model

In [21]:
H = model.fit_generator(train_generator,validation_data=test_generator,epochs=100,verbose=1,callbacks=[mc,es])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

Epoch 1/100
785/785 [==============================] - 12435s 16s/step - loss: 0.1190 - accuracy: 0.9516 - val_loss: 1.2386 - val_accuracy: 0.9394
Epoch 2/100
785/785 [==============================] - 360s 459ms/step - loss: 0.0804 - accuracy: 0.9678 - val_loss: 1.2215 - val_accuracy: 0.9601
Epoch 3/100
785/785 [==============================] - 363s 462ms/step - loss: 0.0697 - accuracy: 0.9725 - val_loss: 1.3172 - val_accuracy: 0.9601
Epoch 4/100
785/785 [==============================] - 364s 463ms/step - loss: 0.0628 - accuracy: 0.9750 - val_loss: 1.3791 - val_accuracy: 0.9617
Epoch 5/100
785/785 [==============================] - 365s 465ms/step - loss: 0.0564 - accuracy: 0.9775 - val_loss: 1.4478 - val_accuracy: 0.9569
Epoch 6/100
785/785 [==============================] - 364s 463ms/step - loss: 0.0544 - accuracy: 0.9793 - val_loss: 1.5138 - val_accuracy: 0.9569
Epoch 7/100
785/785 [==============================] - 366s 467ms/step - loss: 0.0511 - accuracy: 0.9806 - val_loss: 1.4760 - val_accuracy: 0.9537
Epoch 8/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0472 - accuracy: 0.9817 - val_loss: 1.5260 - val_accuracy: 0.9601
Epoch 9/100
785/785 [==============================] - 378s 481ms/step - loss: 0.0450 - accuracy: 0.9824 - val_loss: 1.5810 - val_accuracy: 0.9601
Epoch 10/100
785/785 [==============================] - 383s 488ms/step - loss: 0.0437 - accuracy: 0.9832 - val_loss: 1.5863 - val_accuracy: 0.9585
Epoch 11/100
785/785 [==============================] - 382s 487ms/step - loss: 0.0402 - accuracy: 0.9845 - val_loss: 1.5143 - val_accuracy: 0.9601
Epoch 12/100
785/785 [==============================] - 389s 496ms/step - loss: 0.0368 - accuracy: 0.9867 - val_loss: 1.5643 - val_accuracy: 0.9601
Epoch 13/100
785/785 [==============================] - 383s 488ms/step - loss: 0.0341 - accuracy: 0.9870 - val_loss: 1.5968 - val_accuracy: 0.9601
Epoch 14/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0332 - accuracy: 0.9878 - val_loss: 1.6893 - val_accuracy: 0.9601
Epoch 15/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0302 - accuracy: 0.9891 - val_loss: 1.8144 - val_accuracy: 0.9617
Epoch 16/100
785/785 [==============================] - 387s 494ms/step - loss: 0.0296 - accuracy: 0.9891 - val_loss: 1.6615 - val_accuracy: 0.9601
Epoch 17/100
785/785 [==============================] - 386s 492ms/step - loss: 0.0297 - accuracy: 0.9898 - val_loss: 1.8563 - val_accuracy: 0.9569
Epoch 18/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0291 - accuracy: 0.9888 - val_loss: 1.9082 - val_accuracy: 0.9617
Epoch 19/100
785/785 [==============================] - 384s 490ms/step - loss: 0.0284 - accuracy: 0.9894 - val_loss: 1.8754 - val_accuracy: 0.9665
Epoch 20/100
785/785 [==============================] - 384s 489ms/step - loss: 0.0277 - accuracy: 0.9890 - val_loss: 1.8689 - val_accuracy: 0.9633
Epoch 21/100
785/785 [==============================] - 381s 486ms/step - loss: 0.0247 - accuracy: 0.9913 - val_loss: 1.8928 - val_accuracy: 0.9617
Epoch 22/100
785/785 [==============================] - 381s 485ms/step - loss: 0.0230 - accuracy: 0.9917 - val_loss: 1.9322 - val_accuracy: 0.9617
Epoch 23/100
785/785 [==============================] - 378s 481ms/step - loss: 0.0224 - accuracy: 0.9912 - val_loss: 1.8897 - val_accuracy: 0.9633
Epoch 24/100
785/785 [==============================] - 378s 482ms/step - loss: 0.0203 - accuracy: 0.9928 - val_loss: 2.0867 - val_accuracy: 0.9585
Epoch 25/100
785/785 [==============================] - 378s 482ms/step - loss: 0.0220 - accuracy: 0.9916 - val_loss: 2.0524 - val_accuracy: 0.9617
Epoch 26/100
785/785 [==============================] - 379s 483ms/step - loss: 0.0209 - accuracy: 0.9926 - val_loss: 1.8708 - val_accuracy: 0.9601
Epoch 27/100
785/785 [==============================] - 375s 478ms/step - loss: 0.0195 - accuracy: 0.9929 - val_loss: 1.9471 - val_accuracy: 0.9617
Epoch 28/100
785/785 [==============================] - 376s 480ms/step - loss: 0.0154 - accuracy: 0.9946 - val_loss: 2.0850 - val_accuracy: 0.9601
Epoch 29/100
785/785 [==============================] - 376s 479ms/step - loss: 0.0205 - accuracy: 0.9930 - val_loss: 2.0068 - val_accuracy: 0.9665
Epoch 30/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0173 - accuracy: 0.9936 - val_loss: 2.0252 - val_accuracy: 0.9617
Epoch 31/100
785/785 [==============================] - 379s 483ms/step - loss: 0.0203 - accuracy: 0.9930 - val_loss: 2.2049 - val_accuracy: 0.9585
Epoch 32/100
785/785 [==============================] - 375s 478ms/step - loss: 0.0158 - accuracy: 0.9939 - val_loss: 2.2168 - val_accuracy: 0.9617
Epoch 33/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0150 - accuracy: 0.9949 - val_loss: 2.1689 - val_accuracy: 0.9649
Epoch 34/100
785/785 [==============================] - 382s 486ms/step - loss: 0.0155 - accuracy: 0.9945 - val_loss: 2.3065 - val_accuracy: 0.9649
Epoch 35/100
442/785 [===============>..............] - ETA: 2:42 - loss: 0.0158 - accuracy: 0.9945

As we can see, training stopped early, as we had expected. The best weights are saved in the file "best_model.h5". In the future, we do not need to train the model again; we can simply load the weights into the model with the command below.
In [22]:
model.load_weights("/content/gdrive/My Drive/best_model.h5")

Evaluating ResNet 50 model on test datasets

Let us now evaluate the performance of our model on the unseen testing data set. evaluate_generator returns the loss followed by the accuracy, and we can see an accuracy of about 99%.

In [23]:
model.evaluate_generator(test_generator)
Out[23]:
[0.0068423871206876, 0.9931576128793124]

Serialize the ResNet Model

It is a common practice to serialize the model architecture to JSON so that an inference program can rebuild it later. Note that to_json() stores only the architecture; the weights are kept separately in best_model.h5. We save our ResNet model as below –

In [24]:
model_json = model.to_json()
with open("/content/gdrive/My Drive/model.json","w") as json_file:
  json_file.write(model_json)

Dogs Vs Cat Classification Inference Program

To recap what we have done till now –
  1. We implemented the ResNet-50 model with Keras.
  2. We saved the best training weights of the model in a file for future use.
  3. We saved the model in JSON format for reusability.
Now it is time to write an inference program that will do the following –
  1. Load the model that we saved in JSON format earlier.
  2. Load the weights that we saved after training the model earlier.
  3. Compile the model.
  4. Load the image that we want to classify.
  5. Perform classification.
For performing these steps we have written a function predict_ as below.
In [25]:
from keras.models import model_from_json
In [26]:
# SGD, cv2, np and cv2_imshow are assumed to be imported in earlier cells
def predict_(image_path):
    # Load the model architecture from the JSON file
    with open('/content/gdrive/My Drive/model.json', 'r') as json_file:
        model_json_c = json_file.read()
    model_c = model_from_json(model_json_c)
    # Load the trained weights
    model_c.load_weights("/content/gdrive/My Drive/best_model.h5")
    # Compile the model with the same loss that was used for training
    opt = SGD(lr=1e-4, momentum=0.9)
    model_c.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
    # Load the image you want to classify and resize it to 224x224
    image = cv2.imread(image_path)
    image = cv2.resize(image, (224, 224))
    cv2_imshow(image)
    # Predict: add a batch dimension, then take the output for the single image
    preds = model_c.predict(np.expand_dims(image, axis=0))[0]
    # The output is a score between 0 and 1, so threshold it instead of testing
    # for exact equality with 0 (assumes class 0 is Cat and class 1 is Dog)
    if preds[0] < 0.5:
        print("Predicted Label: Cat")
    else:
        print("Predicted Label: Dog")
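
The last step of predict_ maps the network output to a label. A sigmoid output is a score between 0 and 1 and is rarely exactly 0, so the label is usually decided with a threshold. Below is a minimal sketch of that mapping; the helper name label_from_score and the 0.5 threshold are assumptions, as is the class order (0 for Cat, 1 for Dog).

```python
def label_from_score(score, threshold=0.5):
    """Map a sigmoid output in [0, 1] to a class name."""
    # Assumption: class index 0 is Cat and class index 1 is Dog.
    return "Dog" if score >= threshold else "Cat"

print(label_from_score(0.03))  # → Cat
print(label_from_score(0.97))  # → Dog
```

If every image comes out as the same class, checking the raw score before thresholding is a quick way to see whether the model or the label mapping is at fault.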

Perform Classification

We will now pass a few images from the Dog and Cat folders to the predict_ function and see how our Keras implementation of ResNet-50 performs.
It can be seen that the predicted label matches the actual class of each image. Congratulations on completing the Keras implementation of ResNet-50 successfully.
In [27]:
predict_("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")

Keras Implementation of ResNet Prediction

In [28]:
predict_("/content/gdrive/My Drive/datasets/test/Cat/10.jpg")

Keras Implementation of ResNet-50 Prediction 2

In [29]:
predict_("/content/gdrive/My Drive/datasets/test/Cat/7.jpg")

Keras Implementation of ResNet Prediction-3

Conclusion

Coming to the end of a long article, we hope you now know how to implement ResNet-50 with Keras. We used the Dogs vs Cats dataset, but you can use any dataset to create your own ResNet-50 model with Keras.

  • Sachin Mohan

    My name is Sachin Mohan, and I am an undergraduate student of Computer Science and Engineering. My area of interest is Artificial Intelligence, specifically Deep Learning and Machine Learning. I have attended various online and offline courses on Machine Learning and Deep Learning from different national and international institutes. My interest in Machine Learning and Deep Learning led me to intern at ISRO, and I also became the 1st Runner-up in the TCS EngiNX 2019 contest. I love to share my knowledge and experience, and my philosophy of learning is "Learning by doing". That is why I believe in education that includes both theoretical and practical knowledge.

10 Responses

    1. Click here to download the pre-trained weights. Thanks Tejaswi for pointing this out, same has been updated in the article as well.

      1. When we are building our own ResNet-50 model, then why are we again taking pre-built model weights?

        1. Sorry for late reply Tejaswi. This is because we are using Transfer Learning method to train a Resnet model with our dataset and labels from scratch. The standard Resnet model is trained on ImageNet data with 1000 image labels only. If you want to build a ResNet model from scratch that works for your own images and categories we are using Transfer Learning for that purpose which makes use of pretrained model without last layer to start with. Do notice we are taking the pretrained weights of ResNet without Top Layer (notop in filename) . We are then kind of appending our own classifier layer on the top and then train it with our own dataset (Transfer Learning).

  1. Hello, I have tried this code and the prediction is going wrong. It is giving dog for both the cat and dog. What should be done in this case?

    1. How many training data have you taken for each of dog and cat? Also the number of epochs and training accuracy ?

  2. Hello
    When running model.fit in In [21], it is telling me that I need to compile using model.compile(optimizer, loss).
