Introduction
In this article, we will walk through a Keras implementation of the ResNet-50 architecture from scratch. ResNet-50 (Residual Network) is a deep neural network that is used as a backbone for many computer vision applications like object detection, image segmentation, etc. ResNet was created by the researchers Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, and it won the ImageNet challenge in 2015 with an error rate of 3.57%. It also addressed the vanishing gradient problem that is common in very deep neural networks.
In our Keras implementation of ResNet-50 architecture, we will use the famous Dogs Vs Cats dataset from Kaggle. Moreover, we will use Google Colab to leverage its free GPU for fast training.
- Also Read – Learn Image Classification with Deep Neural Network using Keras
- Also Read – 7 Popular Image Classification Models in ImageNet Challenge (ILSVRC) Competition History
- Also Read – Keras Implementation of VGG16 Architecture from Scratch
Architecture of ResNet
In recent years of the Deep Learning revolution, neural networks have become deeper, with state-of-the-art networks going from just a few layers (e.g., VGG16) to over a hundred layers. The main benefit of a very deep network is that it can represent very complex functions. It can also learn features at many different levels of abstraction, for example, edges (at the lower layers) to very complex features (at the deeper layers) in the case of an image.

However, using a deeper network doesn’t always produce favorable outcomes. A major barrier to training very deep neural networks is the phenomenon of vanishing gradients. Very deep networks often have a gradient signal that goes to zero quickly, making gradient descent slow. More specifically, as you backpropagate from the final layer back to the first layer, you multiply by a weight matrix at each step, so the gradient can decrease exponentially fast toward zero and hinder the training process.
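We can see this decay numerically with a small NumPy sketch (illustrative only, not part of the tutorial's model code): repeatedly multiplying a gradient vector by small random weight matrices, as backpropagation does at each layer, shrinks its norm exponentially.

```python
import numpy as np

# Illustrative sketch of vanishing gradients: one matrix multiply per "layer",
# mimicking the backward pass through a 50-layer network with small weights.
rng = np.random.default_rng(0)
grad = np.ones(64)
norms = []
for layer in range(50):
    W = rng.normal(scale=0.04, size=(64, 64))  # small random weight matrix
    grad = W.T @ grad                          # one backpropagation step
    norms.append(np.linalg.norm(grad))

print(f"gradient norm after 1 layer:   {norms[0]:.3e}")
print(f"gradient norm after 50 layers: {norms[-1]:.3e}")
```

The norm after 50 layers is vanishingly small compared to the norm after one layer, which is exactly why plain very-deep networks train so poorly and why ResNet's skip connections matter.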
Skip Connection
In ResNet architecture, a “shortcut” or a “skip connection” allows the gradient to be directly backpropagated to earlier layers:

The image on the top shows the “main path” through the network. The image on the bottom adds a shortcut to the main path. By stacking these ResNet blocks on top of each other, you can form a very deep network.
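The idea behind the skip connection can be sketched in a few lines of plain NumPy (an illustrative toy, not the article's Keras code): the block's output is the element-wise sum of the transformed main path F(x) and the untouched input x, so gradients always have an identity path to flow through.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_block(x, W1, W2):
    """Toy residual block: out = relu(F(x) + x), with F a small 2-layer MLP.

    The `+ x` term is the skip connection: it bypasses the weight
    multiplications of the main path, so the gradient can reach earlier
    layers without being repeatedly scaled down.
    """
    main_path = relu(x @ W1) @ W2  # F(x), the "main path"
    return relu(main_path + x)     # add the shortcut, then activate

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 8))                 # batch of 4 feature vectors
W1 = rng.normal(scale=0.1, size=(8, 8))
W2 = rng.normal(scale=0.1, size=(8, 8))
out = residual_block(x, W1, W2)
print(out.shape)  # same shape as the input, as the addition requires
```

Note that the addition forces the input and output of the main path to have the same shape, which is exactly the distinction between the two ResNet block types described next.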
There are two main types of blocks used in a ResNet, depending mainly on whether the input/output dimensions are the same or different.
1. Identity Block
The identity block is the standard block used in ResNets and corresponds to the case where the input activation has the same dimension as the output activation.

2. Convolutional Block
We can use this type of block when the input and output dimensions don’t match up. The difference with the identity block is that there is a CONV2D layer in the shortcut path.

ResNet-50
The ResNet-50 model consists of 5 stages, each with a convolutional block and identity blocks. Each convolutional block has 3 convolution layers, and each identity block also has 3 convolution layers. ResNet-50 has over 23 million trainable parameters.
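The "50" in the name comes from counting the weighted layers: the initial convolution, 3 convolutions in each of the 16 bottleneck blocks (stages 2 through 5 contain 3, 4, 6, and 3 blocks respectively), and the final fully connected layer. A quick check of that arithmetic:

```python
# Counting ResNet-50's weighted layers (the 1x1 shortcut convolutions in the
# convolutional blocks are conventionally excluded from the name's count).
blocks_per_stage = [3, 4, 6, 3]   # stages 2, 3, 4, 5
convs_per_block = 3               # 1x1 -> 3x3 -> 1x1 bottleneck

conv_layers = convs_per_block * sum(blocks_per_stage)  # 48 bottleneck convs
total = 1 + conv_layers + 1       # conv1 + bottleneck convs + final FC layer
print(total)  # → 50
```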

Dogs Vs Cats Kaggle Dataset

Dogs Vs Cats is a famous dataset from Kaggle that contains 25,000 images of dogs and cats for training a classification model. This dataset is usually used for introductory lessons on convolutional neural networks.
Setting up Google Colab

We have uploaded the dataset to our Google Drive, but before we can use it in Colab we have to mount our Google Drive directory onto the runtime environment as shown below. This command generates a URL; click it, authenticate with your Google account, copy the authorization code back here, and press Enter.
from google.colab import drive
drive.mount('/content/gdrive')
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly Enter your authorization code: ·········· Mounted at /content/gdrive
ResNet-50 Keras Implementation
Importing libraries and setting up GPU
import cv2
import numpy as np
import os
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import keras
from keras.models import Sequential, Model,load_model
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping,ModelCheckpoint
from google.colab.patches import cv2_imshow
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D,MaxPool2D
from keras.preprocessing import image
from keras.initializers import glorot_uniform
Using TensorFlow backend.
K.tensorflow_backend._get_available_gpus()
Define training and testing path of dataset

train_path="/content/gdrive/My Drive/datasets/train"
test_path="/content/gdrive/My Drive/datasets/test"
class_names=os.listdir(train_path)
class_names_test=os.listdir(test_path)
print(class_names)
print(class_names_test)
['cat', 'dog'] ['Cat', 'Dog']
# Sample dataset images
image_dog=cv2.imread("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")
cv2_imshow(image_dog)
image_cat=cv2.imread("/content/gdrive/My Drive/datasets/test/Cat/5.jpg")
cv2_imshow(image_cat)


Preparation of datasets
Sometimes when we try to load an entire dataset at once, there is not enough memory on the machine.
Keras provides the ImageDataGenerator class that defines the configuration for image data preparation and augmentation. The generator progressively loads the images in your dataset, allowing you to work with very large datasets containing thousands or millions of images that may not fit into system memory.
train_datagen = ImageDataGenerator(zoom_range=0.15,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.15)
test_datagen = ImageDataGenerator()
train_generator = train_datagen.flow_from_directory("/content/gdrive/My Drive/datasets/train",target_size=(224, 224),batch_size=32,shuffle=True,class_mode='binary')
test_generator = test_datagen.flow_from_directory("/content/gdrive/My Drive/datasets/test",target_size=(224,224),batch_size=32,shuffle=False,class_mode='binary')
Found 25000 images belonging to 2 classes. Found 627 images belonging to 2 classes.
Implementation of Identity Block
Let us implement the identity block in Keras –
def identity_block(X, f, filters, stage, block):
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters

    X_shortcut = X

    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    X = Add()([X, X_shortcut])  # skip connection
    X = Activation('relu')(X)

    return X
Implementation of Convolutional Block
Let us implement the convolutional block in Keras –
def convolutional_block(X, f, filters, stage, block, s=2):
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters

    X_shortcut = X

    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)

    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)

    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    # The shortcut path has its own CONV2D + BatchNorm so its output
    # matches the dimensions of the main path before the addition.
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X
Implementation of ResNet-50
In this Keras implementation of ResNet-50, we have not defined the fully connected layer in the network. We will see later why.
def ResNet50(input_shape=(224, 224, 3)):
    X_input = Input(input_shape)

    X = ZeroPadding2D((3, 3))(X_input)

    # Stage 1
    X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)

    # Stage 2
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='b')
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='c')

    # Stage 3
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='b')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='c')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='d')

    # Stage 4
    X = convolutional_block(X, f=3, filters=[256, 256, 1024], stage=4, block='a', s=2)
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='b')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='c')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='d')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='e')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='f')

    # Stage 5
    X = convolutional_block(X, f=3, filters=[512, 512, 2048], stage=5, block='a', s=2)
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='b')
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='c')

    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)

    model = Model(inputs=X_input, outputs=X, name='ResNet50')

    return model
base_model = ResNet50(input_shape=(224, 224, 3))
Quick Concept about Transfer Learning
A deep convolutional neural network can take days to train, and its training requires a lot of computational resources. To overcome this, we use transfer learning in this Keras implementation of ResNet-50.
Transfer learning is a technique whereby a deep neural network model is first trained on a problem similar to the problem being solved. One or more layers from the trained model are then used in a new model trained on the problem of interest. In simple words, transfer learning refers to a process where a model trained on one problem is reused in some way on a second, related problem.
Here we manually define the fully connected layers so that we can output the required number of classes while still leveraging the pre-trained model.
We will reuse model weights from pre-trained models developed for standard computer vision benchmark datasets like ImageNet. Since we replace the last layers with our own three dense layers, the pre-trained weights cannot include them; that is why we download the pre-trained weights without the top layers.
headModel = base_model.output
headModel = Flatten()(headModel)
headModel = Dense(256, activation='relu', name='fc1', kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel = Dense(128, activation='relu', name='fc2', kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel = Dense(1, activation='sigmoid', name='fc3', kernel_initializer=glorot_uniform(seed=0))(headModel)
Finally, let us create the model, which takes its input from the base model's input layer and its output from the last layer of the head model.
model = Model(inputs=base_model.input, outputs=headModel)
Here is the model summary
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 230, 230, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 zero_padding2d_1[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0]
activation_4[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0]
activation_7[0][0]
__________________________________________________________________________________________________
activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_10[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0]
activation_13[0][0]
__________________________________________________________________________________________________
activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0]
activation_16[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0]
activation_19[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0]
activation_25[0][0]
__________________________________________________________________________________________________
activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0]
activation_28[0][0]
__________________________________________________________________________________________________
activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0]
activation_31[0][0]
__________________________________________________________________________________________________
activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0]
activation_34[0][0]
__________________________________________________________________________________________________
activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_37[0][0]
__________________________________________________________________________________________________
activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_45[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0]
activation_43[0][0]
__________________________________________________________________________________________________
activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0]
activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 2048) 0 activation_49[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 32768) 0 average_pooling2d_1[0][0]
__________________________________________________________________________________________________
fc1 (Dense) (None, 256) 8388864 flatten_1[0][0]
__________________________________________________________________________________________________
fc2 (Dense) (None, 128) 32896 fc1[0][0]
__________________________________________________________________________________________________
fc3 (Dense) (None, 1) 129 fc2[0][0]
==================================================================================================
Total params: 32,009,601
Trainable params: 31,956,481
Non-trainable params: 53,120
__________________________________________________________________________________________________
Load the pre-trained weights of the model –
(Click here to download the pre-trained weights.)
base_model.load_weights("/content/gdrive/My Drive/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5")
As we know, the initial layers of a network learn very general features, while layers higher up learn patterns more specific to the task the network is being trained on. Using this property, we keep the initial layers intact (freeze them) and retrain only the later layers for our task. This is termed fine-tuning of a network.
The advantage of fine-tuning is that we do not have to train the entire network from scratch, so much less training data is required. Since fewer parameters need to be updated, training also takes less time.
In Keras, each layer has an attribute called “trainable”. To freeze the weights of a particular layer, we set this attribute to False, indicating that the layer should not be trained. We can then go over the layers one by one and choose which ones to train.
In our case, we are freezing all the convolutional blocks of the model.
for layer in base_model.layers:
    layer.trainable = False
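Here we freeze every layer of the base model, but fine-tuning often unfreezes just the last stage instead. Since the ResNet layer names follow the resNx_… pattern visible in the summary above, that selection can be sketched by name prefix. The following is a minimal illustration using a stand-in class (with a real Keras model you would iterate over base_model.layers directly):

```python
# Minimal stand-in for Keras layers, used only to illustrate the
# name-prefix selection logic (hypothetical layer names from the summary).
class Layer:
    def __init__(self, name):
        self.name = name
        self.trainable = True

layers = [Layer(n) for n in
          ["conv1", "res4f_branch2c", "res5a_branch2a", "res5c_branch2c"]]

# Unfreeze only the last ResNet stage ("res5..."), freeze everything else.
for layer in layers:
    layer.trainable = layer.name.startswith("res5")

print([(l.name, l.trainable) for l in layers])
```

The same one-line loop works unchanged on a real model, because Keras layers expose the same `name` and `trainable` attributes.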
Let us print all the layers in our ResNet-50 model. As you can see, up to the final AveragePooling layer the trainable attribute is False, which means the parameters of these layers will not be updated during training, while the Flatten layer and the three Dense layers at the end have trainable set to True, so their parameters will be updated.
for layer in model.layers:
    print(layer, layer.trainable)
<keras.engine.input_layer.InputLayer object at 0x7f5bdd4602e8> False <keras.layers.convolutional.ZeroPadding2D object at 0x7f5bd41ecfd0> False <keras.layers.convolutional.Conv2D object at 0x7f5bc042f048> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c7f5390> False <keras.layers.core.Activation object at 0x7f5b6c7f5e80> False <keras.layers.pooling.MaxPooling2D object at 0x7f5bc043db00> False <keras.layers.convolutional.Conv2D object at 0x7f5c349fc198> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf9ee80> False <keras.layers.core.Activation object at 0x7f5b6cf9ee48> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf47a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf4f080> False <keras.layers.core.Activation object at 0x7f5b6cf5f518> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf08a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf3b7f0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf0f160> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cecb8d0> False <keras.layers.merge.Add object at 0x7f5b6ced15f8> False <keras.layers.core.Activation object at 0x7f5b6cee9860> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce83898> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ce8ca90> False <keras.layers.core.Activation object at 0x7f5b6ce8ccf8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce3ef60> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3efd0> False <keras.layers.core.Activation object at 0x7f5b6ce4c6d8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce78e80> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdfe3c8> False <keras.layers.merge.Add object at 0x7f5b6ce06b38> False <keras.layers.core.Activation object at 0x7f5b6ce1d898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce3b898> False 
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3b9b0> False <keras.layers.core.Activation object at 0x7f5b6cdc0b70> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cdf2f98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdf2d30> False <keras.layers.core.Activation object at 0x7f5b6cd826a0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cdade48> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdb6390> False <keras.layers.merge.Add object at 0x7f5b6cdbdb00> False <keras.layers.core.Activation object at 0x7f5b6cd52860> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cd6f898> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cd78a90> False <keras.layers.core.Activation object at 0x7f5b6cd78c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cd2bf98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cd2bd30> False <keras.layers.core.Activation object at 0x7f5b6cd3b6a0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cce3e80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc989b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cced3c8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cca4a90> False <keras.layers.merge.Add object at 0x7f5b6ccad7b8> False <keras.layers.core.Activation object at 0x7f5b6cc45a20> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc60a58> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cc67c50> False <keras.layers.core.Activation object at 0x7f5b6cc67eb8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc1a9b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cc20080> False <keras.layers.core.Activation object at 0x7f5b6cc28898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cbd4ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cbdb588> 
False <keras.layers.merge.Add object at 0x7f5b6cbe3cf8> False <keras.layers.core.Activation object at 0x7f5b6cbf8a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb96a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb9ec88> False <keras.layers.core.Activation object at 0x7f5b6cb9ee48> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb509b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb59080> False <keras.layers.core.Activation object at 0x7f5b6cb5e898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb0dba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb115c0> False <keras.layers.merge.Add object at 0x7f5b6cb1cd30> False <keras.layers.core.Activation object at 0x7f5b6cb2ea58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cacda90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cad5a58> False <keras.layers.core.Activation object at 0x7f5b6cad5ac8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca869b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca90080> False <keras.layers.core.Activation object at 0x7f5b6ca98898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca41ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca4a5c0> False <keras.layers.merge.Add object at 0x7f5b6ca52d30> False <keras.layers.core.Activation object at 0x7f5b6ca66a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca02a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca09c88> False <keras.layers.core.Activation object at 0x7f5b6ca09e80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca3d9b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c9c4080> False <keras.layers.core.Activation object at 0x7f5b6c9cc898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c9f8ba8> False 
<keras.layers.convolutional.Conv2D object at 0x7f5b6c9acb70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c97e5c0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c9bac50> False <keras.layers.merge.Add object at 0x7f5b6c941978> False <keras.layers.core.Activation object at 0x7f5b6c95bc50> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c971b70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c979e48> False <keras.layers.core.Activation object at 0x7f5b6c8ff5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c92edd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c934320> False <keras.layers.core.Activation object at 0x7f5b6c93ca90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c8e8c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8ee7b8> False <keras.layers.merge.Add object at 0x7f5b6c8f7f28> False <keras.layers.core.Activation object at 0x7f5b6c894c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c8a9ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8afe80> False <keras.layers.core.Activation object at 0x7f5b6c8bb5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c865dd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c86d320> False <keras.layers.core.Activation object at 0x7f5b6c873a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c822c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8257b8> False <keras.layers.merge.Add object at 0x7f5b6c82def0> False <keras.layers.core.Activation object at 0x7f5b6c789c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c7a0b70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c7a8e48> False <keras.layers.core.Activation object at 0x7f5b6c7af5c0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c758da0> False 
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c7602e8> False <keras.layers.core.Activation object at 0x7f5b6c768a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c714c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c71b7b8> False <keras.layers.merge.Add object at 0x7f5b6c722f28> False <keras.layers.core.Activation object at 0x7f5b6c6bfc88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c6d4ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c6dae80> False <keras.layers.core.Activation object at 0x7f5b6c6e35f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c68cdd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c696320> False <keras.layers.core.Activation object at 0x7f5b6c69fa58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c649c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c6507b8> False <keras.layers.merge.Add object at 0x7f5b6c658f28> False <keras.layers.core.Activation object at 0x7f5b6c675c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c60cba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c613e80> False <keras.layers.core.Activation object at 0x7f5b6c61c5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c5c8dd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c5cf320> False <keras.layers.core.Activation object at 0x7f5b6c5d7a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c57fc88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c5897b8> False <keras.layers.merge.Add object at 0x7f5b6c58ff28> False <keras.layers.core.Activation object at 0x7f5b6c5adc88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c542ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c547e80> False <keras.layers.core.Activation object at 0x7f5b6c5535f8> False 
<keras.layers.convolutional.Conv2D object at 0x7f5b6c57cdd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c505320> False <keras.layers.core.Activation object at 0x7f5b6c50ca90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c537c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c4e8da0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4bd7b8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4f8c88> False <keras.layers.merge.Add object at 0x7f5b6c47fba8> False <keras.layers.core.Activation object at 0x7f5b6c49ee80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c4b3d68> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4b8cf8> False <keras.layers.core.Activation object at 0x7f5b6c4427f0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c46e7f0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c474518> False <keras.layers.core.Activation object at 0x7f5b6c3fec88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c426e80> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c42a9b0> False <keras.layers.merge.Add object at 0x7f5b6c42aac8> False <keras.layers.core.Activation object at 0x7f5b6c3d2eb8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c3e8d68> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c3f1cf8> False <keras.layers.core.Activation object at 0x7f5b6c3f97b8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c3a6f98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c3aa4e0> False <keras.layers.core.Activation object at 0x7f5b6c3b6c50> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c35de48> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c365978> False <keras.layers.merge.Add object at 0x7f5b6c365a90> False <keras.layers.core.Activation object at 0x7f5b6c30aeb8> False 
<keras.layers.pooling.AveragePooling2D object at 0x7f5b6c311fd0> False <keras.layers.core.Flatten object at 0x7f5b6c2c22b0> True <keras.layers.core.Dense object at 0x7f5b6c2c2c18> True <keras.layers.core.Dense object at 0x7f5b6c2c7b00> True <keras.layers.core.Dense object at 0x7f5b6c2d6780> True
We then compile the model using the compile function. This function expects three parameters: the optimizer, the loss function, and the metrics of performance. The optimizer is the stochastic gradient descent algorithm we are going to use. We use the binary_crossentropy loss function since we are doing a binary classification.
opt = SGD(lr=1e-3, momentum=0.9)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
Early Stopping
A common difficulty in training is choosing the number of epochs: too few can underfit, too many can overfit the training data. To overcome this problem, the concept of Early Stopping is used.
In this technique, we can specify an arbitrarily large number of training epochs and stop training once the model performance stops improving on a hold out validation dataset. Keras supports the early stopping of training via a callback called EarlyStopping.
Below are various arguments in EarlyStopping.
- monitor – This allows us to specify the performance measure to monitor in order to end training.
- mode – It is used to specify whether the objective of the chosen metric is to be maximized or minimized.
- verbose – To discover the training epoch on which training was stopped, the “verbose” argument can be set to 1. Once stopped, the callback will print the epoch number.
- patience – The first sign of no further improvement may not be the best time to stop training. This is because the model may coast into a plateau of no improvement or even get slightly worse before getting much better. We can account for this by adding a delay to the trigger in terms of the number of epochs on which we would like to see no improvement. This can be done by setting the “patience” argument.
es=EarlyStopping(monitor='val_accuracy', mode='max', verbose=1, patience=20)
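To build intuition for how patience interacts with a noisy validation metric, the stopping rule can be sketched in plain Python. This mimics the mode='max' behavior described above; it is a simplified sketch, not the Keras implementation itself:

```python
def early_stop_epoch(val_metric_per_epoch, patience):
    """Return the 1-based epoch training would stop at, or None.

    Mimics EarlyStopping with mode='max': stop once `patience` epochs
    pass without a new best value of the monitored metric.
    """
    best = float("-inf")
    wait = 0
    for epoch, value in enumerate(val_metric_per_epoch, start=1):
        if value > best:
            best, wait = value, 0   # new best: reset the patience counter
        else:
            wait += 1               # no improvement this epoch
            if wait >= patience:
                return epoch
    return None  # never triggered

# val_accuracy peaks at epoch 3, then 3 epochs pass without a new best
print(early_stop_epoch([0.90, 0.93, 0.95, 0.94, 0.95, 0.94], patience=3))  # -> 6
```

Note that an epoch merely matching the previous best does not reset the counter; only a strict improvement does.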
Model Checkpoint
The EarlyStopping callback will stop training once triggered, but the model at the end of training may not be the model with the best performance on the validation dataset.
An additional callback is required that will save the best model observed during training for later use. This is known as the ModelCheckpoint callback.
The ModelCheckpoint callback is flexible in the way it can be used, but in this case, we will use it only to save the best model observed during training as defined by a chosen performance measure on the validation dataset.
mc = ModelCheckpoint('/content/gdrive/My Drive/best_model.h5', monitor='val_accuracy', mode='max', save_best_only=True, verbose=1)
Training the Model
H = model.fit_generator(train_generator,validation_data=test_generator,epochs=100,verbose=1,callbacks=[mc,es])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead. Epoch 1/100 785/785 [==============================] - 12435s 16s/step - loss: 0.1190 - accuracy: 0.9516 - val_loss: 1.2386 - val_accuracy: 0.9394 Epoch 2/100 785/785 [==============================] - 360s 459ms/step - loss: 0.0804 - accuracy: 0.9678 - val_loss: 1.2215 - val_accuracy: 0.9601 Epoch 3/100 785/785 [==============================] - 363s 462ms/step - loss: 0.0697 - accuracy: 0.9725 - val_loss: 1.3172 - val_accuracy: 0.9601 Epoch 4/100 785/785 [==============================] - 364s 463ms/step - loss: 0.0628 - accuracy: 0.9750 - val_loss: 1.3791 - val_accuracy: 0.9617 Epoch 5/100 785/785 [==============================] - 365s 465ms/step - loss: 0.0564 - accuracy: 0.9775 - val_loss: 1.4478 - val_accuracy: 0.9569 Epoch 6/100 785/785 [==============================] - 364s 463ms/step - loss: 0.0544 - accuracy: 0.9793 - val_loss: 1.5138 - val_accuracy: 0.9569 Epoch 7/100 785/785 [==============================] - 366s 467ms/step - loss: 0.0511 - accuracy: 0.9806 - val_loss: 1.4760 - val_accuracy: 0.9537 Epoch 8/100 785/785 [==============================] - 377s 480ms/step - loss: 0.0472 - accuracy: 0.9817 - val_loss: 1.5260 - val_accuracy: 0.9601 Epoch 9/100 785/785 [==============================] - 378s 481ms/step - loss: 0.0450 - accuracy: 0.9824 - val_loss: 1.5810 - val_accuracy: 0.9601 Epoch 10/100 785/785 [==============================] - 383s 488ms/step - loss: 0.0437 - accuracy: 0.9832 - val_loss: 1.5863 - val_accuracy: 0.9585 Epoch 11/100 785/785 [==============================] - 382s 487ms/step - loss: 0.0402 - accuracy: 0.9845 - val_loss: 1.5143 - val_accuracy: 0.9601 Epoch 12/100 785/785 [==============================] - 389s 496ms/step - loss: 0.0368 - accuracy: 0.9867 - val_loss: 1.5643 - val_accuracy: 0.9601 Epoch 13/100 785/785 
[==============================] - 383s 488ms/step - loss: 0.0341 - accuracy: 0.9870 - val_loss: 1.5968 - val_accuracy: 0.9601 Epoch 14/100 785/785 [==============================] - 385s 490ms/step - loss: 0.0332 - accuracy: 0.9878 - val_loss: 1.6893 - val_accuracy: 0.9601 Epoch 15/100 785/785 [==============================] - 385s 490ms/step - loss: 0.0302 - accuracy: 0.9891 - val_loss: 1.8144 - val_accuracy: 0.9617 Epoch 16/100 785/785 [==============================] - 387s 494ms/step - loss: 0.0296 - accuracy: 0.9891 - val_loss: 1.6615 - val_accuracy: 0.9601 Epoch 17/100 785/785 [==============================] - 386s 492ms/step - loss: 0.0297 - accuracy: 0.9898 - val_loss: 1.8563 - val_accuracy: 0.9569 Epoch 18/100 785/785 [==============================] - 385s 490ms/step - loss: 0.0291 - accuracy: 0.9888 - val_loss: 1.9082 - val_accuracy: 0.9617 Epoch 19/100 785/785 [==============================] - 384s 490ms/step - loss: 0.0284 - accuracy: 0.9894 - val_loss: 1.8754 - val_accuracy: 0.9665 Epoch 20/100 785/785 [==============================] - 384s 489ms/step - loss: 0.0277 - accuracy: 0.9890 - val_loss: 1.8689 - val_accuracy: 0.9633 Epoch 21/100 785/785 [==============================] - 381s 486ms/step - loss: 0.0247 - accuracy: 0.9913 - val_loss: 1.8928 - val_accuracy: 0.9617 Epoch 22/100 785/785 [==============================] - 381s 485ms/step - loss: 0.0230 - accuracy: 0.9917 - val_loss: 1.9322 - val_accuracy: 0.9617 Epoch 23/100 785/785 [==============================] - 378s 481ms/step - loss: 0.0224 - accuracy: 0.9912 - val_loss: 1.8897 - val_accuracy: 0.9633 Epoch 24/100 785/785 [==============================] - 378s 482ms/step - loss: 0.0203 - accuracy: 0.9928 - val_loss: 2.0867 - val_accuracy: 0.9585 Epoch 25/100 785/785 [==============================] - 378s 482ms/step - loss: 0.0220 - accuracy: 0.9916 - val_loss: 2.0524 - val_accuracy: 0.9617 Epoch 26/100 785/785 [==============================] - 379s 483ms/step - loss: 0.0209 - 
accuracy: 0.9926 - val_loss: 1.8708 - val_accuracy: 0.9601 Epoch 27/100 785/785 [==============================] - 375s 478ms/step - loss: 0.0195 - accuracy: 0.9929 - val_loss: 1.9471 - val_accuracy: 0.9617 Epoch 28/100 785/785 [==============================] - 376s 480ms/step - loss: 0.0154 - accuracy: 0.9946 - val_loss: 2.0850 - val_accuracy: 0.9601 Epoch 29/100 785/785 [==============================] - 376s 479ms/step - loss: 0.0205 - accuracy: 0.9930 - val_loss: 2.0068 - val_accuracy: 0.9665 Epoch 30/100 785/785 [==============================] - 377s 480ms/step - loss: 0.0173 - accuracy: 0.9936 - val_loss: 2.0252 - val_accuracy: 0.9617 Epoch 31/100 785/785 [==============================] - 379s 483ms/step - loss: 0.0203 - accuracy: 0.9930 - val_loss: 2.2049 - val_accuracy: 0.9585 Epoch 32/100 785/785 [==============================] - 375s 478ms/step - loss: 0.0158 - accuracy: 0.9939 - val_loss: 2.2168 - val_accuracy: 0.9617 Epoch 33/100 785/785 [==============================] - 377s 480ms/step - loss: 0.0150 - accuracy: 0.9949 - val_loss: 2.1689 - val_accuracy: 0.9649 Epoch 34/100 785/785 [==============================] - 382s 486ms/step - loss: 0.0155 - accuracy: 0.9945 - val_loss: 2.3065 - val_accuracy: 0.9649 Epoch 35/100 442/785 [===============>..............] - ETA: 2:42 - loss: 0.0158 - accuracy: 0.9945
model.load_weights("/content/gdrive/My Drive/best_model.h5")
Evaluating the ResNet-50 Model on the Test Dataset
Let us now evaluate the performance of our model on the unseen test dataset. We can see an accuracy of around 99%.
model.evaluate_generator(test_generator)
[0.0068423871206876, 0.9931576128793124]
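The two numbers returned are the compiled metrics in order, i.e. loss followed by accuracy. Unpacking the printed result above makes this explicit (a trivial sketch using the values shown):

```python
# evaluate_generator returns metrics in compile order: [loss, accuracy]
loss, accuracy = [0.0068423871206876, 0.9931576128793124]

print(f"Test loss: {loss:.4f}")          # Test loss: 0.0068
print(f"Test accuracy: {accuracy:.2%}")  # Test accuracy: 99.32%
```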
Serialize the ResNet Model
It is best practice to save the model architecture in JSON format so that an inference program can reuse it in the future. So we will save our ResNet model as below –
model_json = model.to_json()
with open("/content/gdrive/My Drive/model.json", "w") as json_file:
    json_file.write(model_json)
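Note that model.to_json() simply returns a JSON string describing the layer configuration, and saving it is the usual Python file round trip. A minimal stand-in using the stdlib json module (toy_config and the local model.json path are hypothetical illustrations, not the real architecture):

```python
import json

# Hypothetical stand-in for the architecture string model.to_json() returns
toy_config = json.dumps({"class_name": "Model", "config": {"name": "resnet50"}})

# Write the JSON string to disk, exactly as done for model.json above
with open("model.json", "w") as json_file:
    json_file.write(toy_config)

# Reading it back recovers the identical architecture description
with open("model.json", "r") as json_file:
    restored = json_file.read()
```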
Dogs vs Cats Classification Inference Program
So far in this tutorial we have:
- Implemented the ResNet-50 model with Keras.
- Saved the best training weights of the model in a file for future use.
- Saved the model in JSON format for reusability.
The inference program will now:
- Load the model that we saved in JSON format earlier.
- Load the weights that we saved after training the model.
- Compile the model.
- Load the image that we want to classify.
- Perform the classification.
from keras.models import model_from_json
from keras.optimizers import SGD
from google.colab.patches import cv2_imshow
import cv2
import numpy as np

def predict_(image_path):
    # Load the model architecture from the JSON file
    json_file = open('/content/gdrive/My Drive/model.json', 'r')
    model_json_c = json_file.read()
    json_file.close()
    model_c = model_from_json(model_json_c)
    # Load the trained weights
    model_c.load_weights("/content/gdrive/My Drive/best_model.h5")
    # Compile the model
    opt = SGD(lr=1e-4, momentum=0.9)
    model_c.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
    # Load and resize the image we want to classify
    image = cv2.imread(image_path)
    image = cv2.resize(image, (224, 224))
    cv2_imshow(image)
    # Predict class probabilities and pick the most likely label
    preds = model_c.predict(np.expand_dims(image, axis=0))[0]
    if np.argmax(preds) == 0:
        print("Predicted Label: Cat")
    else:
        print("Predicted Label: Dog")
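The if/else at the end of predict_ maps the two-way softmax output to a class label by taking the index of the highest probability. The same decision logic in isolation, as a pure-Python sketch (predicted_label is a hypothetical helper, and the Cat-then-Dog ordering assumes the generator's alphabetical label order):

```python
def predicted_label(probs, labels=("Cat", "Dog")):
    # Index of the highest predicted probability, like np.argmax(preds)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best]

print(predicted_label([0.92, 0.08]))  # Cat
print(predicted_label([0.15, 0.85]))  # Dog
```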
Perform Classification
predict_("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")

predict_("/content/gdrive/My Drive/datasets/test/Cat/10.jpg")

predict_("/content/gdrive/My Drive/datasets/test/Cat/7.jpg")

Conclusion
Coming to the end of a long article, we hope you now know how to implement ResNet-50 with Keras. We used the Dogs vs Cats dataset, but you can use almost any dataset to create your own ResNet-50 model with Keras.
My name is Sachin Mohan, an undergraduate student of Computer Science and Engineering. My area of interest is Artificial Intelligence, specifically Deep Learning and Machine Learning. I have attended various online and offline courses on Machine Learning and Deep Learning from different national and international institutes. My interest in Machine Learning and Deep Learning led me to intern at ISRO and to become the 1st runner-up in the TCS EngiNX 2019 contest. I always love to share my knowledge and experience, and my philosophy toward learning is "learning by doing". That is why I believe in an education that includes both theoretical and practical knowledge.
10 Responses
Hello,
How did you take the pre-trained model weights?
Click here to download the pre-trained weights. Thanks Tejaswi for pointing this out, same has been updated in the article as well.
When we are building our own ResNet-50 model, then why are we again taking pre-built model weights?
Sorry for the late reply, Tejaswi. This is because we are using the transfer learning method to train a ResNet model on our own dataset and labels. The standard ResNet model is trained on ImageNet data with 1000 image labels only. To build a ResNet model that works for your own images and categories, transfer learning starts from the pretrained model without its last layer. Do notice we are taking the pretrained weights of ResNet without the top layer (notop in the filename). We then append our own classifier layer on top and train it with our own dataset (transfer learning).
Hello, I have tried this code and the prediction is going wrong. It is giving dog for both the cat and dog. What should be done in this case?
How many training images did you take for each of dog and cat? Also, what were the number of epochs and the training accuracy?
Hello, can you help us in making the code work and also in developing the user interface.
we are using skin dataset, from Kaggle – https://www.kaggle.com/c/siim-isic-melanoma-classification
Unfortunately Tejaswi we do not provide any such service to help in coding.
Hello
When running model.fit in In[21], it tells me that I need to compile using model.compile(optimizer, loss).
It was left out by mistake. The line to compile the model has been added now, please check.