Introduction
In this article, we walk through a Keras implementation of the ResNet-50 architecture from scratch. ResNet (Residual Network) is a family of deep neural networks, and its 50-layer variant, ResNet-50, is widely used as a backbone for computer vision applications such as object detection and image segmentation. ResNet was created by four researchers, Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, and it won the ImageNet challenge in 2015 with a top-5 error rate of 3.57%. It also addressed the vanishing gradient problem that is common in very deep neural networks.
In our Keras implementation of ResNet-50 architecture, we will use the famous Dogs Vs Cats dataset from Kaggle. Moreover, we will use Google Colab to leverage its free GPU for fast training.
Architecture of ResNet
In recent years, neural networks have become deeper, with state-of-the-art networks going from fewer than twenty layers (e.g., VGG16) to over a hundred layers. The main benefit of a very deep network is that it can represent very complex functions. It can also learn features at many different levels of abstraction, for example, from edges (at the lower layers) to very complex features (at the deeper layers) in the case of an image.
However, using a deeper network doesn’t always produce favorable outcomes. A major barrier to training very deep neural networks is the phenomenon of vanishing gradients. Very deep networks often have a gradient signal that goes to zero quickly, which makes gradient descent slow. More specifically, as you backpropagate from the final layer back to the first layer, you multiply by the weight matrix at each step, so the gradient can shrink exponentially towards zero and stall the training process.
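To make the exponential decay concrete, here is a minimal pure-Python sketch. It assumes a toy network in which every layer scales the backpropagated gradient by the same factor (the hypothetical `layer_factor`). Real networks multiply by weight matrices and activation derivatives, but the compounding effect is similar.

```python
def backprop_gradient_magnitude(num_layers, layer_factor=0.5, initial_gradient=1.0):
    """Gradient magnitude reaching the first layer of a toy network.

    Assumption: each layer multiplies the gradient by `layer_factor`.
    """
    gradient = initial_gradient
    for _ in range(num_layers):
        gradient *= layer_factor  # one backpropagation step through one layer
    return gradient


# The deeper the network, the smaller the gradient that reaches the first layer.
for depth in (5, 20, 50):
    print(depth, backprop_gradient_magnitude(depth))
```

With a factor below 1, the gradient at the first layer becomes negligible long before the network reaches the depth of ResNet-50.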
Skip Connection
In ResNet architecture, a “shortcut” or a “skip connection” allows the gradient to be directly backpropagated to earlier layers:
The image on the top shows the “main path” through the network. The image on the bottom adds a shortcut to the main path. By stacking these ResNet blocks on top of each other, you can form a very deep network.
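The effect of the shortcut on the gradient can be sketched with scalars. A residual block computes y = x + F(x), so its derivative is 1 + F'(x) instead of F'(x) alone. The snippet below assumes every block has the same branch derivative (the hypothetical `branch_derivative`), which is a simplification used only for illustration.

```python
def gradient_through_plain_stack(num_blocks, branch_derivative):
    """Gradient through blocks computing y = F(x): a product of F'(x) terms."""
    gradient = 1.0
    for _ in range(num_blocks):
        gradient *= branch_derivative
    return gradient


def gradient_through_residual_stack(num_blocks, branch_derivative):
    """Gradient through blocks computing y = x + F(x): a product of (1 + F'(x)) terms."""
    gradient = 1.0
    for _ in range(num_blocks):
        gradient *= 1.0 + branch_derivative  # the "1" comes from the shortcut
    return gradient


# 16 blocks, each with a small branch derivative (illustrative values only)
print(gradient_through_plain_stack(16, 0.01))
print(gradient_through_residual_stack(16, 0.01))
```

In the plain stack the gradient collapses towards zero, while the identity term contributed by each shortcut keeps it close to 1 in the residual stack.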
Two main types of blocks are used in a ResNet, depending mainly on whether the input and output dimensions are the same or different.
1. Identity Block
The identity block is the standard block used in ResNets and corresponds to the case where the input activation has the same dimension as the output activation.
2. Convolutional Block
We can use this type of block when the input and output dimensions don’t match up. The difference with the identity block is that there is a CONV2D layer in the shortcut path.
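The rule for picking a block can be written as a small helper. This is only a sketch: `choose_block` is a hypothetical function, not part of Keras, and it assumes the shapes differ either because the channel count changes or because the block downsamples with a stride greater than 1.

```python
def choose_block(in_channels, out_channels, stride=1):
    """Return which block type fits the given input/output dimensions."""
    if in_channels == out_channels and stride == 1:
        # Input and output activations have the same shape: add them directly.
        return "identity_block"
    # Shapes differ: the shortcut needs its own convolution to match them.
    return "convolutional_block"


print(choose_block(256, 256))           # same shape
print(choose_block(64, 256))            # channel count changes
print(choose_block(256, 512, stride=2)) # channels and spatial size change
```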
ResNet-50
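The ResNet-50 model consists of 5 stages. Stage 1 is a 7x7 convolution followed by batch normalization, ReLU, and max pooling. Stages 2 to 5 each consist of one convolutional block followed by several identity blocks (2, 3, 5, and 2 respectively). Each convolutional block and each identity block has 3 convolution layers on its main path. ResNet-50 has over 23 million trainable parameters.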
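The "50" in the name can be checked with a little arithmetic. The sketch below uses the usual counting convention for the original architecture, which is an assumption stated here rather than something the code enforces: the convolutions on the shortcut paths are not counted, and the network ends in a single fully connected layer. The block counts per stage match the implementation later in this article.

```python
# Identity blocks per stage; each stage also has one convolutional block.
IDENTITY_BLOCKS_PER_STAGE = {2: 2, 3: 3, 4: 5, 5: 2}
CONVS_PER_BLOCK = 3  # convolutions on the main path of every block


def count_weight_layers():
    stem_conv = 1  # the 7x7 convolution in stage 1
    blocks = sum(1 + n for n in IDENTITY_BLOCKS_PER_STAGE.values())  # 3 + 4 + 6 + 3
    final_fc = 1   # single fully connected layer of the original network
    return stem_conv + blocks * CONVS_PER_BLOCK + final_fc


print(count_weight_layers())
```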
Dogs Vs Cats Kaggle Dataset
Dogs Vs Cats is a famous dataset from Kaggle that contains 25,000 images of dogs and cats for training a classification model. This dataset is usually used for introductory lessons on convolutional neural networks.
Setting up Google Colab
We have uploaded the dataset to our Google Drive, but before we can use it in Colab, we have to mount the Drive directory in the runtime environment, as shown below. The command prints a URL: open it, authenticate with your Google account, copy the authorization code, paste it into the prompt, and press Enter.
from google.colab import drive
drive.mount('/content/gdrive')
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly
Enter your authorization code: ··········
Mounted at /content/gdrive
ResNet-50 Keras Implementation
Importing libraries and setting up GPU
import cv2
import numpy as np
import os
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import keras
from keras.models import Sequential, Model,load_model
from keras.optimizers import SGD
from keras.callbacks import EarlyStopping,ModelCheckpoint
from google.colab.patches import cv2_imshow
from keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D,MaxPool2D
from keras.preprocessing import image
from keras.initializers import glorot_uniform
Using TensorFlow backend.
K.tensorflow_backend._get_available_gpus()
Define training and testing path of dataset
train_path="/content/gdrive/My Drive/datasets/train"
test_path="/content/gdrive/My Drive/datasets/test"
class_names=os.listdir(train_path)
class_names_test=os.listdir(test_path)
print(class_names)
print(class_names_test)
['cat', 'dog']
['Cat', 'Dog']
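Note that the folder names differ in case between the training set ('cat', 'dog') and the test set ('Cat', 'Dog'). As far as we know, flow_from_directory assigns label indices from the alphanumerically sorted subdirectory names, so both layouts should map cats to 0 and dogs to 1. The sketch below mimics that rule with a hypothetical helper, `class_indices_from_dirs`. It is worth confirming on the real generators by inspecting their `class_indices` attribute.

```python
def class_indices_from_dirs(dir_names):
    """Map each class folder to an integer label, in sorted order (assumed rule)."""
    return {name: index for index, name in enumerate(sorted(dir_names))}


print(class_indices_from_dirs(['dog', 'cat']))  # training folders
print(class_indices_from_dirs(['Dog', 'Cat']))  # test folders
```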
#Sample datasets images
image_dog=cv2.imread("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")
cv2_imshow(image_dog)
image_cat=cv2.imread("/content/gdrive/My Drive/datasets/test/Cat/5.jpg")
cv2_imshow(image_cat)
Preparation of datasets
Sometimes a dataset is too large to load into memory all at once on the machine we are using.
Keras provides the ImageDataGenerator class that defines the configuration for image data preparation and augmentation. The generator progressively loads the images in your dataset, allowing you to work with very large datasets containing thousands or millions of images that may not fit into system memory.
train_datagen = ImageDataGenerator(zoom_range=0.15,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.15)
test_datagen = ImageDataGenerator()
train_generator = train_datagen.flow_from_directory("/content/gdrive/My Drive/datasets/train",target_size=(224, 224),batch_size=32,shuffle=True,class_mode='binary')
test_generator = test_datagen.flow_from_directory("/content/gdrive/My Drive/datasets/test",target_size=(224,224),batch_size=32,shuffle=False,class_mode='binary')
Found 25000 images belonging to 2 classes.
Found 627 images belonging to 2 classes.
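The idea of loading data progressively can be illustrated with a plain Python generator. This is a simplified stand-in, not how ImageDataGenerator is implemented: `batch_generator` and `steps_per_epoch` are hypothetical helpers, and the items are placeholders rather than decoded images.

```python
import math


def batch_generator(items, batch_size):
    """Yield one batch at a time instead of materializing all batches at once."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def steps_per_epoch(num_images, batch_size):
    """Number of batches needed to see every image once."""
    return math.ceil(num_images / batch_size)


print(steps_per_epoch(25000, 32))  # training set from the output above
print(steps_per_epoch(627, 32))    # test set from the output above

# Only one batch exists in memory at any time while iterating.
for batch in batch_generator(list(range(70)), 32):
    print(len(batch))
```

The last batch is smaller than the others whenever the dataset size is not a multiple of the batch size.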
Implementation of Identity Block
Let us implement the identity block in Keras –
def identity_block(X, f, filters, stage, block):
    # Layer names follow the pattern res<stage><block>_branch / bn<stage><block>_branch
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters
    # Save the input so it can be added back after the main path
    X_shortcut = X
    # Main path, first component: 1x1 convolution
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    # Main path, second component: f x f convolution
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    # Main path, third component: 1x1 convolution (ReLU is applied after the addition)
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)
    X = Add()([X, X_shortcut])  # skip connection
    X = Activation('relu')(X)
    return X
Implementation of Convolutional Block
Let us implement the convolutional block in Keras –
def convolutional_block(X, f, filters, stage, block, s=2):
    # Layer names follow the pattern res<stage><block>_branch / bn<stage><block>_branch
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters
    # Save the input for the shortcut path
    X_shortcut = X
    # Main path, first component: 1x1 convolution with stride s (downsamples when s > 1)
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    # Main path, second component: f x f convolution
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same', name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    # Main path, third component: 1x1 convolution (ReLU is applied after the addition)
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid', name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)
    # Shortcut path: 1x1 convolution with the same stride, so both shapes match at the addition
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid', name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X
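As a sanity check on the convolutional block, its parameter count can be worked out by hand. The sketch below uses the standard formulas: a convolution has kernel height x kernel width x input channels x output channels weights plus one bias per output channel, and batch normalization stores four values per channel. The helper names are hypothetical. The per-layer results agree with the values reported for the res2a layers in the model summary later in this article (4160, 36928, 16640, and 1024).

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a square-kernel convolution."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def batchnorm_params(channels):
    """gamma, beta, moving mean, and moving variance for each channel."""
    return 4 * channels


def convolutional_block_params(f, in_channels, filters):
    """Total parameters of one convolutional block, shortcut included."""
    F1, F2, F3 = filters
    main_path = (conv2d_params(1, in_channels, F1)
                 + conv2d_params(f, F1, F2)
                 + conv2d_params(1, F2, F3))
    shortcut = conv2d_params(1, in_channels, F3)
    batchnorm = batchnorm_params(F1) + batchnorm_params(F2) + 2 * batchnorm_params(F3)
    return main_path + shortcut + batchnorm


# Stage 2, block 'a': 64 input channels, filters [64, 64, 256], f = 3
print(conv2d_params(1, 64, 64))    # res2a_branch2a
print(conv2d_params(3, 64, 64))    # res2a_branch2b
print(conv2d_params(1, 64, 256))   # res2a_branch2c and res2a_branch1
print(convolutional_block_params(3, 64, [64, 64, 256]))
```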
Implementation of ResNet-50
In this Keras implementation of ResNet-50, we have not defined the fully connected layer inside the network. We will see why later.
def ResNet50(input_shape=(224, 224, 3)):
    X_input = Input(input_shape)
    X = ZeroPadding2D((3, 3))(X_input)
    # Stage 1: 7x7 convolution, batch normalization, ReLU, max pooling
    X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)
    # Stage 2: one convolutional block (no downsampling) and two identity blocks
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='b')
    X = identity_block(X, 3, [64, 64, 256], stage=2, block='c')
    # Stage 3: one convolutional block and three identity blocks
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='b')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='c')
    X = identity_block(X, 3, [128, 128, 512], stage=3, block='d')
    # Stage 4: one convolutional block and five identity blocks
    X = convolutional_block(X, f=3, filters=[256, 256, 1024], stage=4, block='a', s=2)
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='b')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='c')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='d')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='e')
    X = identity_block(X, 3, [256, 256, 1024], stage=4, block='f')
    # Stage 5: one convolutional block and two identity blocks
    X = convolutional_block(X, f=3, filters=[512, 512, 2048], stage=5, block='a', s=2)
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='b')
    X = identity_block(X, 3, [512, 512, 2048], stage=5, block='c')
    # Final pooling; the fully connected head is added separately
    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)
    model = Model(inputs=X_input, outputs=X, name='ResNet50')
    return model
base_model = ResNet50(input_shape=(224, 224, 3))
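To see how the spatial size shrinks through the network defined above, the sketch below traces the height (equal to the width) of the feature map for a 224x224 input. It uses the standard output-size formulas for 'valid' and 'same' padding and assumes that AveragePooling2D uses a stride equal to its pool size when no stride is given. The helper names are hypothetical. Inside each block only the first 1x1 convolution (stride s) changes the size, so one step per stage is enough.

```python
import math


def valid_output_size(size, kernel, stride):
    """Output size of a convolution or pooling layer with 'valid' padding."""
    return (size - kernel) // stride + 1


def same_output_size(size, stride):
    """Output size of a convolution or pooling layer with 'same' padding."""
    return math.ceil(size / stride)


def trace_spatial_size(input_size=224):
    sizes = {}
    size = input_size + 2 * 3                      # ZeroPadding2D((3, 3))
    size = valid_output_size(size, 7, 2)           # 7x7 conv, stride 2
    sizes['conv1'] = size
    size = valid_output_size(size, 3, 2)           # 3x3 max pooling, stride 2
    sizes['max_pool'] = size
    for stage, stride in ((2, 1), (3, 2), (4, 2), (5, 2)):
        size = valid_output_size(size, 1, stride)  # first 1x1 conv of the stage
        sizes['stage%d' % stage] = size
    size = same_output_size(size, 2)               # 2x2 average pooling, 'same'
    sizes['avg_pool'] = size
    return sizes


print(trace_spatial_size())
```

The values for conv1, max pooling, and stages 2 to 4 (112, 55, 55, 28, 14) agree with the output shapes listed in the model summary further below.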
Quick Concept about Transfer Learning
Deep convolutional neural networks can take days to train and require substantial computational resources. To overcome this, we use transfer learning in this Keras implementation of ResNet-50.
Transfer learning is a technique whereby a deep neural network model is first trained on a problem similar to the problem that is being solved. One or more layers from the trained model are then used in a new model trained on the problem of interest. In simple words, transfer learning refers to a process where a model trained on one problem is used in some way on a second related problem.
Here we define the fully connected layers manually, so that the model outputs the classes we need while still taking advantage of the pre-trained model.
We will reuse the weights of a model pre-trained on a standard computer vision benchmark dataset such as ImageNet. We have downloaded pre-trained weights that do not include the top layers, because we replace the last layer with our own head. The pre-trained weights cannot contain weights for the three new dense layers, which is why we download the version without the top.
headModel = base_model.output
headModel = Flatten()(headModel)
headModel=Dense(256, activation='relu', name='fc1',kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel=Dense(128, activation='relu', name='fc2',kernel_initializer=glorot_uniform(seed=0))(headModel)
headModel = Dense( 1,activation='sigmoid', name='fc3',kernel_initializer=glorot_uniform(seed=0))(headModel)
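The final Dense layer has a single unit with a sigmoid activation, so the model outputs one probability per image. The sketch below shows how such a probability could be turned into a label. Two things here are assumptions rather than something stated in this article: the class order (cat = 0, dog = 1, which is what alphabetical ordering of the folders would give) and the 0.5 threshold, which is a common default.

```python
import math


def sigmoid(z):
    """Squash a raw score into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predicted_label(probability, threshold=0.5, class_names=("cat", "dog")):
    """Pick the class for a binary sigmoid output (assumed class order)."""
    return class_names[1] if probability >= threshold else class_names[0]


for score in (-2.0, 0.0, 2.0):
    probability = sigmoid(score)
    print(score, round(probability, 3), predicted_label(probability))
```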
Finally, let us create the model, which takes its input from the input layer of the base model and produces its output from the last layer of the head model.
model = Model(inputs=base_model.input, outputs=headModel)
Here is the model summary
model.summary()
Model: "model_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) (None, 224, 224, 3) 0 __________________________________________________________________________________________________ zero_padding2d_1 (ZeroPadding2D (None, 230, 230, 3) 0 input_1[0][0] __________________________________________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 64) 9472 zero_padding2d_1[0][0] __________________________________________________________________________________________________ bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0] __________________________________________________________________________________________________ activation_1 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0] __________________________________________________________________________________________________ max_pooling2d_1 (MaxPooling2D) (None, 55, 55, 64) 0 activation_1[0][0] __________________________________________________________________________________________________ res2a_branch2a (Conv2D) (None, 55, 55, 64) 4160 max_pooling2d_1[0][0] __________________________________________________________________________________________________ bn2a_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2a[0][0] __________________________________________________________________________________________________ activation_2 (Activation) (None, 55, 55, 64) 0 bn2a_branch2a[0][0] __________________________________________________________________________________________________ res2a_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_2[0][0] __________________________________________________________________________________________________ bn2a_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2a_branch2b[0][0] 
__________________________________________________________________________________________________ activation_3 (Activation) (None, 55, 55, 64) 0 bn2a_branch2b[0][0] __________________________________________________________________________________________________ res2a_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_3[0][0] __________________________________________________________________________________________________ res2a_branch1 (Conv2D) (None, 55, 55, 256) 16640 max_pooling2d_1[0][0] __________________________________________________________________________________________________ bn2a_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2a_branch2c[0][0] __________________________________________________________________________________________________ bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256) 1024 res2a_branch1[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 55, 55, 256) 0 bn2a_branch2c[0][0] bn2a_branch1[0][0] __________________________________________________________________________________________________ activation_4 (Activation) (None, 55, 55, 256) 0 add_1[0][0] __________________________________________________________________________________________________ res2b_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_4[0][0] __________________________________________________________________________________________________ bn2b_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2b_branch2a[0][0] __________________________________________________________________________________________________ activation_5 (Activation) (None, 55, 55, 64) 0 bn2b_branch2a[0][0] __________________________________________________________________________________________________ res2b_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_5[0][0] __________________________________________________________________________________________________ bn2b_branch2b (BatchNormalizati 
(None, 55, 55, 64) 256 res2b_branch2b[0][0] __________________________________________________________________________________________________ activation_6 (Activation) (None, 55, 55, 64) 0 bn2b_branch2b[0][0] __________________________________________________________________________________________________ res2b_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_6[0][0] __________________________________________________________________________________________________ bn2b_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2b_branch2c[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 55, 55, 256) 0 bn2b_branch2c[0][0] activation_4[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 55, 55, 256) 0 add_2[0][0] __________________________________________________________________________________________________ res2c_branch2a (Conv2D) (None, 55, 55, 64) 16448 activation_7[0][0] __________________________________________________________________________________________________ bn2c_branch2a (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2a[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 55, 55, 64) 0 bn2c_branch2a[0][0] __________________________________________________________________________________________________ res2c_branch2b (Conv2D) (None, 55, 55, 64) 36928 activation_8[0][0] __________________________________________________________________________________________________ bn2c_branch2b (BatchNormalizati (None, 55, 55, 64) 256 res2c_branch2b[0][0] __________________________________________________________________________________________________ activation_9 (Activation) (None, 55, 55, 64) 0 bn2c_branch2b[0][0] 
__________________________________________________________________________________________________ res2c_branch2c (Conv2D) (None, 55, 55, 256) 16640 activation_9[0][0] __________________________________________________________________________________________________ bn2c_branch2c (BatchNormalizati (None, 55, 55, 256) 1024 res2c_branch2c[0][0] __________________________________________________________________________________________________ add_3 (Add) (None, 55, 55, 256) 0 bn2c_branch2c[0][0] activation_7[0][0] __________________________________________________________________________________________________ activation_10 (Activation) (None, 55, 55, 256) 0 add_3[0][0] __________________________________________________________________________________________________ res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_10[0][0] __________________________________________________________________________________________________ bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0] __________________________________________________________________________________________________ activation_11 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0] __________________________________________________________________________________________________ res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_11[0][0] __________________________________________________________________________________________________ bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0] __________________________________________________________________________________________________ activation_12 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0] __________________________________________________________________________________________________ res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_12[0][0] __________________________________________________________________________________________________ res3a_branch1 (Conv2D) 
(None, 28, 28, 512) 131584 activation_10[0][0] __________________________________________________________________________________________________ bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0] __________________________________________________________________________________________________ bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0] __________________________________________________________________________________________________ add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0] bn3a_branch1[0][0] __________________________________________________________________________________________________ activation_13 (Activation) (None, 28, 28, 512) 0 add_4[0][0] __________________________________________________________________________________________________ res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_13[0][0] __________________________________________________________________________________________________ bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0] __________________________________________________________________________________________________ activation_14 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0] __________________________________________________________________________________________________ res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_14[0][0] __________________________________________________________________________________________________ bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0] __________________________________________________________________________________________________ activation_15 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0] __________________________________________________________________________________________________ res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_15[0][0] 
__________________________________________________________________________________________________ bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0] __________________________________________________________________________________________________ add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0] activation_13[0][0] __________________________________________________________________________________________________ activation_16 (Activation) (None, 28, 28, 512) 0 add_5[0][0] __________________________________________________________________________________________________ res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_16[0][0] __________________________________________________________________________________________________ bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0] __________________________________________________________________________________________________ activation_17 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0] __________________________________________________________________________________________________ res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_17[0][0] __________________________________________________________________________________________________ bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0] __________________________________________________________________________________________________ activation_18 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0] __________________________________________________________________________________________________ res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_18[0][0] __________________________________________________________________________________________________ bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0] __________________________________________________________________________________________________ add_6 (Add) 
(None, 28, 28, 512) 0 bn3c_branch2c[0][0] activation_16[0][0] __________________________________________________________________________________________________ activation_19 (Activation) (None, 28, 28, 512) 0 add_6[0][0] __________________________________________________________________________________________________ res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_19[0][0] __________________________________________________________________________________________________ bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0] __________________________________________________________________________________________________ activation_20 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0] __________________________________________________________________________________________________ res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_20[0][0] __________________________________________________________________________________________________ bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0] __________________________________________________________________________________________________ activation_21 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0] __________________________________________________________________________________________________ res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_21[0][0] __________________________________________________________________________________________________ bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0] __________________________________________________________________________________________________ add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0] activation_19[0][0] __________________________________________________________________________________________________ activation_22 (Activation) (None, 28, 28, 512) 0 add_7[0][0] 
__________________________________________________________________________________________________ res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_22[0][0] __________________________________________________________________________________________________ bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0] __________________________________________________________________________________________________ activation_23 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0] __________________________________________________________________________________________________ res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_23[0][0] __________________________________________________________________________________________________ bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0] __________________________________________________________________________________________________ activation_24 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0] __________________________________________________________________________________________________ res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_24[0][0] __________________________________________________________________________________________________ res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_22[0][0] __________________________________________________________________________________________________ bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0] __________________________________________________________________________________________________ bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0] __________________________________________________________________________________________________ add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0] bn4a_branch1[0][0] __________________________________________________________________________________________________ 
activation_25 (Activation) (None, 14, 14, 1024) 0 add_8[0][0] __________________________________________________________________________________________________ res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_25[0][0] __________________________________________________________________________________________________ bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0] __________________________________________________________________________________________________ activation_26 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0] __________________________________________________________________________________________________ res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_26[0][0] __________________________________________________________________________________________________ bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0] __________________________________________________________________________________________________ activation_27 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0] __________________________________________________________________________________________________ res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_27[0][0] __________________________________________________________________________________________________ bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0] __________________________________________________________________________________________________ add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0] activation_25[0][0] __________________________________________________________________________________________________ activation_28 (Activation) (None, 14, 14, 1024) 0 add_9[0][0] __________________________________________________________________________________________________ res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_28[0][0] 
__________________________________________________________________________________________________ bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0] __________________________________________________________________________________________________ activation_29 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0] __________________________________________________________________________________________________ res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_29[0][0] __________________________________________________________________________________________________ bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0] __________________________________________________________________________________________________ activation_30 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0] __________________________________________________________________________________________________ res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_30[0][0] __________________________________________________________________________________________________ bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0] __________________________________________________________________________________________________ add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0] activation_28[0][0] __________________________________________________________________________________________________ activation_31 (Activation) (None, 14, 14, 1024) 0 add_10[0][0] __________________________________________________________________________________________________ res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_31[0][0] __________________________________________________________________________________________________ bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0] __________________________________________________________________________________________________ 
activation_32 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0] __________________________________________________________________________________________________ res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_32[0][0] __________________________________________________________________________________________________ bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0] __________________________________________________________________________________________________ activation_33 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0] __________________________________________________________________________________________________ res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_33[0][0] __________________________________________________________________________________________________ bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0] __________________________________________________________________________________________________ add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0] activation_31[0][0] __________________________________________________________________________________________________ activation_34 (Activation) (None, 14, 14, 1024) 0 add_11[0][0] __________________________________________________________________________________________________ res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_34[0][0] __________________________________________________________________________________________________ bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0] __________________________________________________________________________________________________ activation_35 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0] __________________________________________________________________________________________________ res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_35[0][0] 
__________________________________________________________________________________________________ bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0] __________________________________________________________________________________________________ activation_36 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0] __________________________________________________________________________________________________ res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_36[0][0] __________________________________________________________________________________________________ bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0] __________________________________________________________________________________________________ add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0] activation_34[0][0] __________________________________________________________________________________________________ activation_37 (Activation) (None, 14, 14, 1024) 0 add_12[0][0] __________________________________________________________________________________________________ res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_37[0][0] __________________________________________________________________________________________________ bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0] __________________________________________________________________________________________________ activation_38 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0] __________________________________________________________________________________________________ res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_38[0][0] __________________________________________________________________________________________________ bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0] __________________________________________________________________________________________________ 
activation_39 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0] __________________________________________________________________________________________________ res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_39[0][0] __________________________________________________________________________________________________ bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0] __________________________________________________________________________________________________ add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0] activation_37[0][0] __________________________________________________________________________________________________ activation_40 (Activation) (None, 14, 14, 1024) 0 add_13[0][0] __________________________________________________________________________________________________ res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_40[0][0] __________________________________________________________________________________________________ bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0] __________________________________________________________________________________________________ activation_41 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0] __________________________________________________________________________________________________ res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_41[0][0] __________________________________________________________________________________________________ bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0] __________________________________________________________________________________________________ activation_42 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0] __________________________________________________________________________________________________ res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_42[0][0] 
__________________________________________________________________________________________________ res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_40[0][0] __________________________________________________________________________________________________ bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0] __________________________________________________________________________________________________ bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0] __________________________________________________________________________________________________ add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0] bn5a_branch1[0][0] __________________________________________________________________________________________________ activation_43 (Activation) (None, 7, 7, 2048) 0 add_14[0][0] __________________________________________________________________________________________________ res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_43[0][0] __________________________________________________________________________________________________ bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0] __________________________________________________________________________________________________ activation_44 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0] __________________________________________________________________________________________________ res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_44[0][0] __________________________________________________________________________________________________ bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0] __________________________________________________________________________________________________ activation_45 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0] __________________________________________________________________________________________________ res5b_branch2c (Conv2D) 
(None, 7, 7, 2048) 1050624 activation_45[0][0] __________________________________________________________________________________________________ bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0] __________________________________________________________________________________________________ add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0] activation_43[0][0] __________________________________________________________________________________________________ activation_46 (Activation) (None, 7, 7, 2048) 0 add_15[0][0] __________________________________________________________________________________________________ res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_46[0][0] __________________________________________________________________________________________________ bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0] __________________________________________________________________________________________________ activation_47 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0] __________________________________________________________________________________________________ res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_47[0][0] __________________________________________________________________________________________________ bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0] __________________________________________________________________________________________________ activation_48 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0] __________________________________________________________________________________________________ res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_48[0][0] __________________________________________________________________________________________________ bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0] 
__________________________________________________________________________________________________
add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]
                                                                 activation_46[0][0]
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 7, 7, 2048)   0           add_16[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 2048)   0           activation_49[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 32768)        0           average_pooling2d_1[0][0]
__________________________________________________________________________________________________
fc1 (Dense)                     (None, 256)          8388864     flatten_1[0][0]
__________________________________________________________________________________________________
fc2 (Dense)                     (None, 128)          32896       fc1[0][0]
__________________________________________________________________________________________________
fc3 (Dense)                     (None, 1)            129         fc2[0][0]
==================================================================================================
Total params: 32,009,601
Trainable params: 31,956,481
Non-trainable params: 53,120
__________________________________________________________________________________________________
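The parameter counts in the summary can be checked by hand. The sketch below assumes the standard ResNet-50 bottleneck layout (1x1, 3x3 and 1x1 convolutions, each with a bias), which the summary itself does not show. The helper names are illustrative and are not part of Keras.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, out_channels, use_bias=True):
    """Conv2D weights: one kernel per (input, output) channel pair, plus optional biases."""
    weights = kernel_h * kernel_w * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


def batchnorm_params(channels):
    """BatchNormalization keeps gamma, beta, moving mean and moving variance per channel."""
    return 4 * channels


def dense_params(in_units, out_units):
    """Dense layer: an in_units x out_units weight matrix plus one bias per output unit."""
    return in_units * out_units + out_units


# Kernel sizes are assumed from the standard bottleneck design (1x1 -> 3x3 -> 1x1).
print("res4b_branch2a:", conv2d_params(1, 1, 1024, 256))
print("res4b_branch2b:", conv2d_params(3, 3, 256, 256))
print("res4b_branch2c:", conv2d_params(1, 1, 256, 1024))
print("bn4b_branch2c: ", batchnorm_params(1024))
# fc1 is fed by the flattened 4 x 4 x 2048 tensor (32768 values).
print("fc1:           ", dense_params(4 * 4 * 2048, 256))
```

Under these assumptions the results are 262400, 590080, 263168, 4096 and 8388864, which agree with the corresponding rows of the summary.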
Load the pre-trained weights of the model –
(Click here to download the pre-trained weights.)
base_model.load_weights("/content/gdrive/My Drive/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5")
As we know, the initial layers learn very general features, and as we go higher up the network, the layers tend to learn patterns more specific to the task they are trained on. We therefore keep the initial layers intact (freeze them) and retrain only the later layers for our task. This is called fine-tuning a network.
The advantage of fine-tuning is that we do not have to train the entire network from scratch, so much less training data is required. Fewer parameters need to be updated, so training also takes less time.
In Keras, each layer has an attribute called “trainable”. To freeze the weights of a particular layer, we set this attribute to False, indicating that the layer should not be trained. We then go over each layer and select which layers we want to train.
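To get a feel for the saving, the arithmetic below uses the totals printed in the model summary and assumes that only the three Dense layers (fc1, fc2, fc3) remain trainable after freezing. It is a back-of-the-envelope sketch, not output produced by Keras.

```python
# Parameter counts copied from the model summary above.
total_params = 32_009_601
head_params = {"fc1": 8_388_864, "fc2": 32_896, "fc3": 129}

# Assumption: after freezing the base model, only the dense head is updated.
trainable_after_freeze = sum(head_params.values())
frozen_params = total_params - trainable_after_freeze
share = trainable_after_freeze / total_params

print(f"trainable: {trainable_after_freeze:,}")
print(f"frozen:    {frozen_params:,}")
print(f"share of parameters still being trained: {share:.1%}")
```

Under that assumption, roughly 8.4 million of the 32 million parameters are updated during training, and the remaining 23.6 million stay fixed.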
In our case, we are freezing all the convolutional blocks of the model.
for layer in base_model.layers:
    layer.trainable = False
Let us print all the layers in our ResNet-50 model. As you can see, every layer up to and including the final AveragePooling2D layer is False, which means that the parameters of these layers will not be updated during training. The Flatten layer and the three Dense layers that follow have trainable set to True, so the parameters of the Dense layers are updated during training.
for layer in model.layers:
    print(layer, layer.trainable)
<keras.engine.input_layer.InputLayer object at 0x7f5bdd4602e8> False <keras.layers.convolutional.ZeroPadding2D object at 0x7f5bd41ecfd0> False <keras.layers.convolutional.Conv2D object at 0x7f5bc042f048> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c7f5390> False <keras.layers.core.Activation object at 0x7f5b6c7f5e80> False <keras.layers.pooling.MaxPooling2D object at 0x7f5bc043db00> False <keras.layers.convolutional.Conv2D object at 0x7f5c349fc198> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf9ee80> False <keras.layers.core.Activation object at 0x7f5b6cf9ee48> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf47a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf4f080> False <keras.layers.core.Activation object at 0x7f5b6cf5f518> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf08a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cf3b7f0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cf0f160> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cecb8d0> False <keras.layers.merge.Add object at 0x7f5b6ced15f8> False <keras.layers.core.Activation object at 0x7f5b6cee9860> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce83898> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ce8ca90> False <keras.layers.core.Activation object at 0x7f5b6ce8ccf8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce3ef60> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3efd0> False <keras.layers.core.Activation object at 0x7f5b6ce4c6d8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce78e80> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdfe3c8> False <keras.layers.merge.Add object at 0x7f5b6ce06b38> False <keras.layers.core.Activation object at 0x7f5b6ce1d898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ce3b898> False 
<keras.layers.normalization.BatchNormalization object at 0x7f5b6ce3b9b0> False <keras.layers.core.Activation object at 0x7f5b6cdc0b70> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cdf2f98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdf2d30> False <keras.layers.core.Activation object at 0x7f5b6cd826a0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cdade48> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cdb6390> False <keras.layers.merge.Add object at 0x7f5b6cdbdb00> False <keras.layers.core.Activation object at 0x7f5b6cd52860> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cd6f898> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cd78a90> False <keras.layers.core.Activation object at 0x7f5b6cd78c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cd2bf98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cd2bd30> False <keras.layers.core.Activation object at 0x7f5b6cd3b6a0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cce3e80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc989b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cced3c8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cca4a90> False <keras.layers.merge.Add object at 0x7f5b6ccad7b8> False <keras.layers.core.Activation object at 0x7f5b6cc45a20> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc60a58> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cc67c50> False <keras.layers.core.Activation object at 0x7f5b6cc67eb8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cc1a9b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cc20080> False <keras.layers.core.Activation object at 0x7f5b6cc28898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cbd4ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cbdb588> 
False <keras.layers.merge.Add object at 0x7f5b6cbe3cf8> False <keras.layers.core.Activation object at 0x7f5b6cbf8a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb96a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb9ec88> False <keras.layers.core.Activation object at 0x7f5b6cb9ee48> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb509b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb59080> False <keras.layers.core.Activation object at 0x7f5b6cb5e898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cb0dba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cb115c0> False <keras.layers.merge.Add object at 0x7f5b6cb1cd30> False <keras.layers.core.Activation object at 0x7f5b6cb2ea58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6cacda90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6cad5a58> False <keras.layers.core.Activation object at 0x7f5b6cad5ac8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca869b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca90080> False <keras.layers.core.Activation object at 0x7f5b6ca98898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca41ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca4a5c0> False <keras.layers.merge.Add object at 0x7f5b6ca52d30> False <keras.layers.core.Activation object at 0x7f5b6ca66a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca02a90> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6ca09c88> False <keras.layers.core.Activation object at 0x7f5b6ca09e80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6ca3d9b0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c9c4080> False <keras.layers.core.Activation object at 0x7f5b6c9cc898> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c9f8ba8> False 
<keras.layers.convolutional.Conv2D object at 0x7f5b6c9acb70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c97e5c0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c9bac50> False <keras.layers.merge.Add object at 0x7f5b6c941978> False <keras.layers.core.Activation object at 0x7f5b6c95bc50> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c971b70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c979e48> False <keras.layers.core.Activation object at 0x7f5b6c8ff5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c92edd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c934320> False <keras.layers.core.Activation object at 0x7f5b6c93ca90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c8e8c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8ee7b8> False <keras.layers.merge.Add object at 0x7f5b6c8f7f28> False <keras.layers.core.Activation object at 0x7f5b6c894c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c8a9ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8afe80> False <keras.layers.core.Activation object at 0x7f5b6c8bb5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c865dd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c86d320> False <keras.layers.core.Activation object at 0x7f5b6c873a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c822c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c8257b8> False <keras.layers.merge.Add object at 0x7f5b6c82def0> False <keras.layers.core.Activation object at 0x7f5b6c789c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c7a0b70> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c7a8e48> False <keras.layers.core.Activation object at 0x7f5b6c7af5c0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c758da0> False 
<keras.layers.normalization.BatchNormalization object at 0x7f5b6c7602e8> False <keras.layers.core.Activation object at 0x7f5b6c768a58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c714c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c71b7b8> False <keras.layers.merge.Add object at 0x7f5b6c722f28> False <keras.layers.core.Activation object at 0x7f5b6c6bfc88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c6d4ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c6dae80> False <keras.layers.core.Activation object at 0x7f5b6c6e35f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c68cdd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c696320> False <keras.layers.core.Activation object at 0x7f5b6c69fa58> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c649c88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c6507b8> False <keras.layers.merge.Add object at 0x7f5b6c658f28> False <keras.layers.core.Activation object at 0x7f5b6c675c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c60cba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c613e80> False <keras.layers.core.Activation object at 0x7f5b6c61c5f8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c5c8dd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c5cf320> False <keras.layers.core.Activation object at 0x7f5b6c5d7a90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c57fc88> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c5897b8> False <keras.layers.merge.Add object at 0x7f5b6c58ff28> False <keras.layers.core.Activation object at 0x7f5b6c5adc88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c542ba8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c547e80> False <keras.layers.core.Activation object at 0x7f5b6c5535f8> False 
<keras.layers.convolutional.Conv2D object at 0x7f5b6c57cdd8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c505320> False <keras.layers.core.Activation object at 0x7f5b6c50ca90> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c537c88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c4e8da0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4bd7b8> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4f8c88> False <keras.layers.merge.Add object at 0x7f5b6c47fba8> False <keras.layers.core.Activation object at 0x7f5b6c49ee80> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c4b3d68> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c4b8cf8> False <keras.layers.core.Activation object at 0x7f5b6c4427f0> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c46e7f0> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c474518> False <keras.layers.core.Activation object at 0x7f5b6c3fec88> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c426e80> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c42a9b0> False <keras.layers.merge.Add object at 0x7f5b6c42aac8> False <keras.layers.core.Activation object at 0x7f5b6c3d2eb8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c3e8d68> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c3f1cf8> False <keras.layers.core.Activation object at 0x7f5b6c3f97b8> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c3a6f98> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c3aa4e0> False <keras.layers.core.Activation object at 0x7f5b6c3b6c50> False <keras.layers.convolutional.Conv2D object at 0x7f5b6c35de48> False <keras.layers.normalization.BatchNormalization object at 0x7f5b6c365978> False <keras.layers.merge.Add object at 0x7f5b6c365a90> False <keras.layers.core.Activation object at 0x7f5b6c30aeb8> False 
<keras.layers.pooling.AveragePooling2D object at 0x7f5b6c311fd0> False <keras.layers.core.Flatten object at 0x7f5b6c2c22b0> True <keras.layers.core.Dense object at 0x7f5b6c2c2c18> True <keras.layers.core.Dense object at 0x7f5b6c2c7b00> True <keras.layers.core.Dense object at 0x7f5b6c2d6780> True
We then compile the model using the compile function. This function expects three parameters: the optimizer, the loss function, and the performance metrics. Here the optimizer is stochastic gradient descent (SGD) with a learning rate of 1e-3 and a momentum of 0.9. We use the binary_crossentropy loss function since we are doing binary classification.
opt = SGD(lr=1e-3, momentum=0.9)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
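For intuition about the loss, binary cross-entropy can be computed by hand for a few predictions. The function below is a plain-Python sketch of the textbook formula, not the Keras implementation. The `eps` clipping value and the sample predictions are assumptions made for illustration.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over all samples."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # keep p strictly inside (0, 1)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


# Labels: 1 = dog, 0 = cat (illustrative values, not taken from the dataset).
confident = binary_crossentropy([1, 0], [0.9, 0.1])
unsure = binary_crossentropy([1, 0], [0.6, 0.4])
print("confident predictions:", round(confident, 4))
print("unsure predictions:   ", round(unsure, 4))
```

Predictions that are both correct and confident give the lower loss, which is what the optimizer pushes the network towards.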
Early Stopping
Training for too many epochs can cause the model to overfit the training data, while training for too few can leave it underfit. Early stopping is used to overcome this problem.
In this technique, we can specify an arbitrarily large number of training epochs and stop training once the model performance stops improving on a held-out validation dataset. Keras supports the early stopping of training via a callback called EarlyStopping.
Below are various arguments in EarlyStopping.
- monitor – This allows us to specify the performance measure to monitor in order to end training.
- mode – It specifies whether the monitored metric should be maximized or minimized.
- verbose – To discover the training epoch on which training was stopped, the “verbose” argument can be set to 1. Once stopped, the callback will print the epoch number.
- patience – The first sign of no further improvement may not be the best time to stop training. This is because the model may coast into a plateau of no improvement or even get slightly worse before getting much better. We can account for this by adding a delay to the trigger in terms of the number of epochs on which we would like to see no improvement. This can be done by setting the “patience” argument.
es=EarlyStopping(monitor='val_accuracy', mode='max', verbose=1, patience=20)
Model Checkpoint
The EarlyStopping callback will stop training once triggered, but the model at the end of training may not be the model with the best performance on the validation dataset.
An additional callback is required that will save the best model observed during training for later use. This is known as the ModelCheckpoint callback.
The ModelCheckpoint callback is flexible in the way it can be used, but in this case, we will use it only to save the best model observed during training as defined by a chosen performance measure on the validation dataset.
mc = ModelCheckpoint('/content/gdrive/My Drive/best_model.h5', monitor='val_accuracy', mode='
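The idea of keeping only the best model can be sketched in a few lines of plain Python. The function below mimics what a checkpoint callback does when it is configured to save only on improvement (in Keras this corresponds to `save_best_only=True`, which is an assumption here because the call above is cut off). The function name is hypothetical.

```python
def epochs_that_save(history, mode="max"):
    """Return the 1-based epochs at which a 'keep only the best' checkpoint would be written."""
    best = None
    saved = []
    for epoch, value in enumerate(history, start=1):
        if best is None:
            better = True
        elif mode == "max":
            better = value > best
        else:
            better = value < best
        if better:
            best = value          # remember the new best score
            saved.append(epoch)   # the weights file would be overwritten here
    return saved


# val_accuracy values of the first five epochs in this article's training log.
val_accuracy = [0.9394, 0.9601, 0.9601, 0.9617, 0.9569]
print(epochs_that_save(val_accuracy))  # [1, 2, 4]
```

With this rule the file on disk always holds the weights from the best epoch seen so far, even if later epochs are worse.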
Training the Model
H = model.fit_generator(train_generator,validation_data=test_generator,epochs=100,verbose=1,callbacks=[mc,es])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
Epoch 1/100
785/785 [==============================] - 12435s 16s/step - loss: 0.1190 - accuracy: 0.9516 - val_loss: 1.2386 - val_accuracy: 0.9394
Epoch 2/100
785/785 [==============================] - 360s 459ms/step - loss: 0.0804 - accuracy: 0.9678 - val_loss: 1.2215 - val_accuracy: 0.9601
Epoch 3/100
785/785 [==============================] - 363s 462ms/step - loss: 0.0697 - accuracy: 0.9725 - val_loss: 1.3172 - val_accuracy: 0.9601
Epoch 4/100
785/785 [==============================] - 364s 463ms/step - loss: 0.0628 - accuracy: 0.9750 - val_loss: 1.3791 - val_accuracy: 0.9617
Epoch 5/100
785/785 [==============================] - 365s 465ms/step - loss: 0.0564 - accuracy: 0.9775 - val_loss: 1.4478 - val_accuracy: 0.9569
Epoch 6/100
785/785 [==============================] - 364s 463ms/step - loss: 0.0544 - accuracy: 0.9793 - val_loss: 1.5138 - val_accuracy: 0.9569
Epoch 7/100
785/785 [==============================] - 366s 467ms/step - loss: 0.0511 - accuracy: 0.9806 - val_loss: 1.4760 - val_accuracy: 0.9537
Epoch 8/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0472 - accuracy: 0.9817 - val_loss: 1.5260 - val_accuracy: 0.9601
Epoch 9/100
785/785 [==============================] - 378s 481ms/step - loss: 0.0450 - accuracy: 0.9824 - val_loss: 1.5810 - val_accuracy: 0.9601
Epoch 10/100
785/785 [==============================] - 383s 488ms/step - loss: 0.0437 - accuracy: 0.9832 - val_loss: 1.5863 - val_accuracy: 0.9585
Epoch 11/100
785/785 [==============================] - 382s 487ms/step - loss: 0.0402 - accuracy: 0.9845 - val_loss: 1.5143 - val_accuracy: 0.9601
Epoch 12/100
785/785 [==============================] - 389s 496ms/step - loss: 0.0368 - accuracy: 0.9867 - val_loss: 1.5643 - val_accuracy: 0.9601
Epoch 13/100
785/785 [==============================] - 383s 488ms/step - loss: 0.0341 - accuracy: 0.9870 - val_loss: 1.5968 - val_accuracy: 0.9601
Epoch 14/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0332 - accuracy: 0.9878 - val_loss: 1.6893 - val_accuracy: 0.9601
Epoch 15/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0302 - accuracy: 0.9891 - val_loss: 1.8144 - val_accuracy: 0.9617
Epoch 16/100
785/785 [==============================] - 387s 494ms/step - loss: 0.0296 - accuracy: 0.9891 - val_loss: 1.6615 - val_accuracy: 0.9601
Epoch 17/100
785/785 [==============================] - 386s 492ms/step - loss: 0.0297 - accuracy: 0.9898 - val_loss: 1.8563 - val_accuracy: 0.9569
Epoch 18/100
785/785 [==============================] - 385s 490ms/step - loss: 0.0291 - accuracy: 0.9888 - val_loss: 1.9082 - val_accuracy: 0.9617
Epoch 19/100
785/785 [==============================] - 384s 490ms/step - loss: 0.0284 - accuracy: 0.9894 - val_loss: 1.8754 - val_accuracy: 0.9665
Epoch 20/100
785/785 [==============================] - 384s 489ms/step - loss: 0.0277 - accuracy: 0.9890 - val_loss: 1.8689 - val_accuracy: 0.9633
Epoch 21/100
785/785 [==============================] - 381s 486ms/step - loss: 0.0247 - accuracy: 0.9913 - val_loss: 1.8928 - val_accuracy: 0.9617
Epoch 22/100
785/785 [==============================] - 381s 485ms/step - loss: 0.0230 - accuracy: 0.9917 - val_loss: 1.9322 - val_accuracy: 0.9617
Epoch 23/100
785/785 [==============================] - 378s 481ms/step - loss: 0.0224 - accuracy: 0.9912 - val_loss: 1.8897 - val_accuracy: 0.9633
Epoch 24/100
785/785 [==============================] - 378s 482ms/step - loss: 0.0203 - accuracy: 0.9928 - val_loss: 2.0867 - val_accuracy: 0.9585
Epoch 25/100
785/785 [==============================] - 378s 482ms/step - loss: 0.0220 - accuracy: 0.9916 - val_loss: 2.0524 - val_accuracy: 0.9617
Epoch 26/100
785/785 [==============================] - 379s 483ms/step - loss: 0.0209 - accuracy: 0.9926 - val_loss: 1.8708 - val_accuracy: 0.9601
Epoch 27/100
785/785 [==============================] - 375s 478ms/step - loss: 0.0195 - accuracy: 0.9929 - val_loss: 1.9471 - val_accuracy: 0.9617
Epoch 28/100
785/785 [==============================] - 376s 480ms/step - loss: 0.0154 - accuracy: 0.9946 - val_loss: 2.0850 - val_accuracy: 0.9601
Epoch 29/100
785/785 [==============================] - 376s 479ms/step - loss: 0.0205 - accuracy: 0.9930 - val_loss: 2.0068 - val_accuracy: 0.9665
Epoch 30/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0173 - accuracy: 0.9936 - val_loss: 2.0252 - val_accuracy: 0.9617
Epoch 31/100
785/785 [==============================] - 379s 483ms/step - loss: 0.0203 - accuracy: 0.9930 - val_loss: 2.2049 - val_accuracy: 0.9585
Epoch 32/100
785/785 [==============================] - 375s 478ms/step - loss: 0.0158 - accuracy: 0.9939 - val_loss: 2.2168 - val_accuracy: 0.9617
Epoch 33/100
785/785 [==============================] - 377s 480ms/step - loss: 0.0150 - accuracy: 0.9949 - val_loss: 2.1689 - val_accuracy: 0.9649
Epoch 34/100
785/785 [==============================] - 382s 486ms/step - loss: 0.0155 - accuracy: 0.9945 - val_loss: 2.3065 - val_accuracy: 0.9649
Epoch 35/100
442/785 [===============>..............] - ETA: 2:42 - loss: 0.0158 - accuracy: 0.9945
model.load_weights("/content/gdrive/My Drive/best_model.h5")
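The call above restores the weights stored in best_model.h5 instead of keeping the weights from the last, interrupted epoch. As a rough illustration of what "best" could mean, the sketch below picks the earliest epoch with the highest validation accuracy from the log above. It assumes the checkpoint tracked val_accuracy, which is not confirmed in this section, and it only covers epochs 14 to 34; the helper name best_epoch is hypothetical.

```python
# val_accuracy per epoch, copied from the training log above (epochs 14-34 only).
val_accuracy_by_epoch = {
    14: 0.9601, 15: 0.9617, 16: 0.9601, 17: 0.9569, 18: 0.9617,
    19: 0.9665, 20: 0.9633, 21: 0.9617, 22: 0.9617, 23: 0.9633,
    24: 0.9585, 25: 0.9617, 26: 0.9601, 27: 0.9617, 28: 0.9601,
    29: 0.9665, 30: 0.9617, 31: 0.9585, 32: 0.9617, 33: 0.9649,
    34: 0.9649,
}


def best_epoch(scores):
    """Return the earliest epoch that reaches the highest score."""
    best = max(scores.values())
    return min(epoch for epoch, score in scores.items() if score == best)


epoch = best_epoch(val_accuracy_by_epoch)
print(f"best epoch in this window: {epoch} "
      f"(val_accuracy = {val_accuracy_by_epoch[epoch]:.4f})")
```

Within this window, epochs 19 and 29 tie on validation accuracy, and the sketch keeps the earlier one because a later epoch only replaces it when it is strictly better.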
Evaluating the ResNet-50 Model on the Test Dataset
Let us now evaluate the performance of our model on the unseen test dataset. The call returns the loss followed by the accuracy, and here the test accuracy is about 99.3%.
model.evaluate_generator(test_generator)
[0.0068423871206876, 0.9931576128793124]
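The two numbers in the output are easier to read when paired with their names. The sketch below is a minimal illustration, assuming the usual Keras ordering of the loss first and the compiled metrics after it; the names results, metric_names and report are introduced here only for the example.

```python
# Values copied from the evaluation output above.
results = [0.0068423871206876, 0.9931576128793124]

# Assumed order: loss first, then the metrics passed to compile().
metric_names = ["loss", "accuracy"]

report = dict(zip(metric_names, results))
print(f"test loss: {report['loss']:.4f}")
print(f"test accuracy: {report['accuracy']:.2%}")
```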
Serialize the ResNet Model
It is good practice to save the model architecture in JSON format so that an inference program can rebuild the model later. We will save our ResNet model as follows:
model_json = model.to_json()
with open("/content/gdrive/My Drive/model.json", "w") as json_file:
    json_file.write(model_json)
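The JSON file describes only the architecture; the trained weights stay in the separate best_model.h5 file. To get a feel for what such a file contains, the sketch below parses a small, hand-written stand-in with the standard json module and lists its layers. The stub, its layer names, and the helper summarize_architecture are hypothetical, and the key layout is an assumption based on how Keras typically serializes models, so check it against your own model.json.

```python
import json

# Hypothetical, heavily trimmed stand-in for the text returned by model.to_json().
# A real ResNet-50 file is far larger; the key names follow the usual Keras layout.
model_json = json.dumps({
    "class_name": "Model",
    "config": {
        "name": "resnet50_demo",
        "layers": [
            {"class_name": "InputLayer", "config": {"name": "input_1"}},
            {"class_name": "Conv2D", "config": {"name": "conv1"}},
            {"class_name": "Dense", "config": {"name": "fc2"}},
        ],
    },
})


def summarize_architecture(json_text):
    """Return (layer_name, layer_type) pairs from a Keras-style architecture JSON."""
    spec = json.loads(json_text)
    return [
        (layer["config"]["name"], layer["class_name"])
        for layer in spec["config"]["layers"]
    ]


for name, kind in summarize_architecture(model_json):
    print(f"{name}: {kind}")
```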
Dogs Vs Cats Classification Inference Program
So far, we have done the following:
- We implemented the ResNet-50 model with Keras.
- We saved the best training weights of the model in a file for future use.
- We saved the model in JSON format for reusability.
The inference program will perform these steps:
- Load the model that we saved in JSON format earlier.
- Load the weights that we saved after training the model earlier.
- Compile the model.
- Load the image that we want to classify.
- Perform classification.
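import cv2
import numpy as np
from google.colab.patches import cv2_imshow  # Colab replacement for cv2.imshow
from keras.models import model_from_json
from keras.optimizers import SGD

def predict_(image_path):
    # Load the model architecture from the JSON file
    with open('/content/gdrive/My Drive/model.json', 'r') as json_file:
        model_json_c = json_file.read()
    model_c = model_from_json(model_json_c)
    # Load the weights
    model_c.load_weights("/content/gdrive/My Drive/best_model.h5")
    # Compile the model
    opt = SGD(lr=1e-4, momentum=0.9)
    model_c.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
    # Load the image you want to classify
    image = cv2.imread(image_path)
    image = cv2.resize(image, (224, 224))
    cv2_imshow(image)
    # Predict the image; preds holds the raw scores of the output layer
    preds = model_c.predict(np.expand_dims(image, axis=0))[0]
    # Comparing the scores with 0 never picks a class. A single sigmoid score is
    # thresholded at 0.5; a vector with one score per class uses the largest entry.
    label = int(preds[0] > 0.5) if preds.size == 1 else int(np.argmax(preds))
    if label == 0:
        print("Predicted Label: Cat")
    else:
        print("Predicted Label: Dog")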
Perform Classification
predict_("/content/gdrive/My Drive/datasets/test/Dog/4.jpg")
predict_("/content/gdrive/My Drive/datasets/test/Cat/10.jpg")
predict_("/content/gdrive/My Drive/datasets/test/Cat/7.jpg")
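Turning the raw network output into a label is the step most likely to go wrong, so it may help to isolate it. The sketch below is one possible way to do that, not the article's own code. It assumes class index 0 is Cat and index 1 is Dog (the alphabetical order of the folder names), and it accepts either a single sigmoid score or one score per class, because the output layer is not shown in this section. The helper name decode_prediction and the example scores are hypothetical.

```python
def decode_prediction(preds, class_names=("Cat", "Dog"), threshold=0.5):
    """Map raw model scores to a (label, confidence) pair."""
    scores = [float(p) for p in preds]
    if len(scores) == 1:
        # Single sigmoid unit: the score is the probability of class index 1.
        index = 1 if scores[0] > threshold else 0
        confidence = scores[0] if index == 1 else 1.0 - scores[0]
    else:
        # One score per class: pick the position of the largest score.
        index = max(range(len(scores)), key=scores.__getitem__)
        confidence = scores[index]
    return class_names[index], confidence


# Made-up scores, used only to show both output shapes.
for example in ([0.08, 0.92], [0.2]):
    label, confidence = decode_prediction(example)
    print(f"Predicted Label: {label} ({confidence:.0%})")
```

Returning the confidence alongside the label also makes it easier to spot a model that predicts the same class for every image.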
Conclusion
We have reached the end of a long article, and we hope you now know how to implement ResNet-50 with Keras. We used the Dogs Vs Cats dataset, but you can use any dataset to create your own ResNet-50 model with Keras.
My name is Sachin Mohan, an undergraduate student of Computer Science and Engineering. My area of interest is artificial intelligence, specifically deep learning and machine learning. I have attended various online and offline courses on machine learning and deep learning from different national and international institutes. My interest in machine learning and deep learning led me to intern at ISRO, and I was also the 1st runner-up in the TCS EngiNX 2019 contest. I love to share my knowledge and experience, and my philosophy toward learning is "Learning by doing". That is why I believe in education that includes both theoretical and practical knowledge.
10 Responses
Hello,
How did you obtain the pre-trained model weights?
Click here to download the pre-trained weights. Thanks, Tejaswi, for pointing this out; the article has been updated as well.
When we are building our own ResNet-50 model, then why are we again taking pre-built model weights?
Sorry for the late reply, Tejaswi. This is because we are using transfer learning to train a ResNet model on our own dataset and labels. The standard ResNet model is trained on ImageNet data with 1000 image labels only. To build a ResNet model that works for your own images and categories, we use transfer learning, which starts from the pretrained model without its last layer. Notice that we are taking the pretrained weights of ResNet without the top layer ("notop" in the filename). We then append our own classifier layer on top and train it with our own dataset.
Hello, I have tried this code and the prediction is going wrong. It is giving dog for both the cat and dog. What should be done in this case?
How many training images did you use for each of dog and cat? Also, what were the number of epochs and the training accuracy?
Hello, can you help us in making the code work and also in developing the user interface?
We are using a skin dataset from Kaggle: https://www.kaggle.com/c/siim-isic-melanoma-classification
Unfortunately, Tejaswi, we do not provide any such service to help in coding.
Hello,
When running model.fit in In[21], it tells me that I need to compile using model.compile(optimizer, loss).
It was left out by mistake. The line to compile the model has been added now; please check.