Keras Model Training Functions – fit() vs fit_generator() vs train_on_batch()

Introduction

In this article, we will give you an overview of the Keras functions fit(), fit_generator(), and train_on_batch(), which are used to train a Keras model. We will look at their syntax, compare how they work, and see which one you should use in which scenario.

Keras Fit : fit()

For TensorFlow versions before v2.1 –

The first function for training models is fit(), which is the most common and preferred way to train a model on small or medium-sized datasets.

The Keras fit() function is ideal when –

  • The training dataset is small enough to fit into RAM. If the data is too large to be loaded into RAM at once, fit() will run into memory problems.
  • The training data is used as it is, without on-the-fly data augmentation.

From TensorFlow v2.1 onwards –

The fit() function can also work with generators, so it can be used both for large datasets that do not fit in memory and with ImageDataGenerator() for data augmentation.

This means it can be used in place of the fit_generator() function, which we discuss in the next section.
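As a minimal sketch of this TF 2.1+ behaviour, assuming TensorFlow 2.x with its tf.keras Python API, the snippet below passes a plain Python generator straight to fit(). The generator `batch_generator`, the toy data, and the one-layer model are hypothetical, chosen only to keep the example small.

```python
import numpy as np
import tensorflow as tf


def batch_generator(x, y, batch_size):
    """Hypothetical generator: yield (inputs, targets) batches forever, reshuffling before each pass."""
    n = len(x)
    while True:  # fit() ends each epoch after steps_per_epoch batches
        order = np.random.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# Hypothetical toy regression data standing in for a dataset too big for RAM.
x = np.random.rand(256, 8).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

batch_size = 32
# A generator has no length, so fit() needs steps_per_epoch to know where an epoch ends.
history = model.fit(
    batch_generator(x, y, batch_size),
    steps_per_epoch=len(x) // batch_size,
    epochs=2,
    verbose=0,
)
```

The generator itself never stops; it is steps_per_epoch that tells fit() when one epoch is over.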

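The following section shows the syntax of the fit() function with its main parameters. The signature below is the one used by the R interface to Keras; in Python, fit() is called as a method of the model, model.fit(...), without the object argument.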

Syntax of Keras fit()

fit(object, x = NULL, y = NULL, batch_size = NULL, epochs = 10, verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL, view_metrics = getOption("keras.view_metrics", default = "auto"), validation_split = 0, validation_data = NULL, shuffle = TRUE, class_weight = NULL, sample_weight = NULL, initial_epoch = 0, steps_per_epoch = NULL, validation_steps = NULL, ...)

Parameters

  • object : The model to be trained.
  • x, y : The training inputs (x) and target values (y). They can be a vector, an array, or a matrix.
  • batch_size : The number of samples per gradient update.
  • epochs : The number of epochs to train for, where one epoch is a complete pass over the training data.
  • verbose : Controls how training progress is reported. 0 (silent) prints nothing, 1 displays a progress bar, and 2 prints one line per epoch.
  • shuffle : Whether to shuffle the training data before each epoch.
  • steps_per_epoch : The total number of steps (batches of samples) to draw before one epoch is declared finished and the next epoch starts.
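The sketch below shows fit() with data that fits in memory, using several of the parameters listed above. It assumes TensorFlow 2.x and the Python API, where fit() is a method of the model; the random data and the small network are illustrative placeholders, not a recommended architecture.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 1,000 samples with 20 features and binary labels,
# small enough to hold entirely in RAM (the case fit() is designed for).
rng = np.random.default_rng(0)
x_train = rng.normal(size=(1000, 20)).astype("float32")
y_train = (x_train.sum(axis=1) > 0).astype("float32").reshape(-1, 1)

# A small illustrative model; the layer sizes are an arbitrary choice.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# fit() slices the in-memory arrays into batches of 32, shuffles the training
# samples before each epoch, and holds out the last 20% of them for validation.
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=3,
    verbose=0,          # 0 = silent, 1 = progress bar, 2 = one line per epoch
    shuffle=True,
    validation_split=0.2,
)
```

Here steps_per_epoch is left at its default, so each epoch covers the whole (non-validation) training set.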

Fit Generator : fit_generator()

fit_generator() also trains a Keras model, but it addresses the shortcoming of fit() discussed above: if the dataset is too large to be loaded into RAM at once, fit_generator() is the recommended approach.

fit_generator() is also useful when you need to apply data augmentation.

With this function, a Python generator feeds the data to the model one batch at a time during training, so only the current batch has to be held in memory. Combined with the ImageDataGenerator() class, the generator can also apply data augmentation techniques such as rotation, flipping, zooming, and many more, which helps to reduce overfitting.
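To make batch-wise loading concrete, here is a minimal sketch in plain Python and NumPy. `load_sample` is a hypothetical stand-in for reading one image and its label from disk (here it just fabricates random pixels), and `lazy_batches` builds only one batch at a time.

```python
import numpy as np


def load_sample(i):
    """Hypothetical loader standing in for reading image i and its label from disk."""
    rng = np.random.default_rng(i)
    image = rng.random((32, 32, 3), dtype=np.float32)  # fake pixels for the demo
    label = i % 10
    return image, label


def lazy_batches(num_samples, batch_size):
    """Yield (images, labels) one batch at a time; only the current batch is in memory."""
    for start in range(0, num_samples, batch_size):
        stop = min(start + batch_size, num_samples)
        samples = [load_sample(i) for i in range(start, stop)]
        images = np.stack([image for image, _ in samples])
        labels = np.array([label for _, label in samples])
        yield images, labels


for images, labels in lazy_batches(num_samples=100, batch_size=32):
    print(images.shape, labels.shape)
```

This generator makes a single pass over the data. Generators handed to fit_generator() are expected to loop over their data indefinitely (for example, by wrapping the loop in `while True`), because Keras decides where each epoch ends from steps_per_epoch.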

From TensorFlow v2.1 onwards –

fit_generator() has been deprecated, and its support for generators has been merged into the fit() function discussed above.

Now, let us look at the syntax and parameters of the fit_generator() function in Keras.

Syntax of Keras fit_generator()

fit_generator(object, generator, steps_per_epoch, epochs = 1, verbose = getOption("keras.fit_verbose", default = 1), callbacks = NULL, view_metrics = getOption("keras.view_metrics", default = "auto"), class_weight = NULL, max_queue_size = 10, workers = 1, initial_epoch = 0)

Parameters Used

  • object : The model to be trained.
  • generator : The main parameter of this function: a generator whose output is a list of the form (inputs, targets) or (inputs, targets, sample_weights), each output being one batch of training data.
  • steps_per_epoch : The total number of steps (batches of samples) to draw from the generator before one epoch is declared finished and the next epoch starts.
  • epochs : The number of epochs to train for, where one epoch is steps_per_epoch batches.
  • verbose : Controls how training progress is reported. 0 (silent) prints nothing, 1 displays a progress bar, and 2 prints one line per epoch.
  • callbacks : A list of callback functions to apply during training.
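steps_per_epoch is usually derived from the dataset size and the batch size. The helper below (a hypothetical name, not a Keras function) shows the two common choices: rounding up, so one epoch covers every sample once, or rounding down, as len(trainX) // 32 does in the example that follows, so each epoch draws slightly fewer batches than a full pass over the data.

```python
def steps_per_epoch(num_samples, batch_size, drop_remainder=False):
    """Number of batches in one pass over num_samples samples (hypothetical helper)."""
    if drop_remainder:
        # Round down, like len(trainX) // 32: the final partial batch is not counted.
        return num_samples // batch_size
    # Round up (integer ceiling division): the final partial batch counts as a step.
    return (num_samples + batch_size - 1) // batch_size


print(steps_per_epoch(1000, 32))                       # 32
print(steps_per_epoch(1000, 32, drop_remainder=True))  # 31
```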

The following example uses ImageDataGenerator, the Keras class for image data augmentation, and passes its output to fit_generator(). Its parameters define the augmentation (rotation, zoom, shear, shifts, flips, and how newly exposed pixels are filled), and its flow() method loads the data in batches.

# Assumes trainX, trainY, testX, testY (NumPy arrays) and a compiled `model` already exist
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# performing data augmentation with a training image generator
data_augmentation = ImageDataGenerator(rotation_range=50, zoom_range=0.15,
    fill_mode="nearest", shear_range=0.27, horizontal_flip=False,
    width_shift_range=0.7, height_shift_range=0.5)

# training the model on augmented batches of 32 images
model.fit_generator(data_augmentation.flow(trainX, trainY, batch_size=32),
    validation_data=(testX, testY), steps_per_epoch=len(trainX) // 32,
    epochs=15)
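Since fit_generator() is deprecated from TensorFlow v2.1 onwards, the same training call can be written with fit(). The sketch below assumes TensorFlow 2.x; the random images standing in for trainX, trainY, testX, and testY, the tiny CNN, and the reduced epoch count are hypothetical placeholders. ImageDataGenerator's random rotations, shears, shifts, and zooms require SciPy to be installed.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 64 training and 16 test RGB images (32x32), 10 one-hot classes.
rng = np.random.default_rng(0)
trainX = rng.random((64, 32, 32, 3)).astype("float32")
trainY = tf.keras.utils.to_categorical(rng.integers(0, 10, 64), 10)
testX = rng.random((16, 32, 32, 3)).astype("float32")
testY = tf.keras.utils.to_categorical(rng.integers(0, 10, 16), 10)

# A tiny illustrative CNN; the architecture is arbitrary.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Same augmentation settings as the fit_generator() example above.
data_augmentation = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=50, zoom_range=0.15, fill_mode="nearest", shear_range=0.27,
    horizontal_flip=False, width_shift_range=0.7, height_shift_range=0.5)

# TF 2.1+: the augmented generator is passed straight to fit().
history = model.fit(data_augmentation.flow(trainX, trainY, batch_size=32),
                    validation_data=(testX, testY),
                    steps_per_epoch=len(trainX) // 32,
                    epochs=2, verbose=0)
```

Apart from the method name, the generator and the other arguments are passed exactly as in the fit_generator() call.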

Keras Train on batch : train_on_batch()

As the name suggests, the train_on_batch() function runs a single gradient update on one batch of training data: it performs a forward pass, backpropagates the loss to compute the gradients, and then updates the model parameters.

Unlike fit(), train_on_batch() takes no batch_size argument: whatever data you pass is treated as a single batch, so you have to write your own custom iterator that splits the training data into batches. Those batches can come from any source, including your own data generation or data augmentation code.

The following shows the signature of the train_on_batch() function.

Syntax of Keras train_on_batch()

train_on_batch(x, y, sample_weight=None, class_weight=None, reset_metrics=True)

Parameters Used

  • x : The input data for this batch.
  • y : The target data for this batch.
  • sample_weight : Optional weights for the individual samples in the batch, used to weight their contribution to the loss.
  • class_weight : Optional dictionary mapping class indices to weights, used to weight the loss for each class.
  • reset_metrics : If True (the default), the returned metrics cover only this batch; if False, they are accumulated across successive calls.
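The loop below sketches the custom iterator that train_on_batch() leaves to you, assuming TensorFlow 2.x: it shuffles the sample indices each epoch, slices out batches of 32, and calls train_on_batch() once per batch. The data and model are hypothetical placeholders; because the model is compiled without metrics, each call returns a single scalar loss.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy regression data: 200 samples with 5 features.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 5)).astype("float32")
y = (x @ np.arange(1, 6, dtype="float32")).reshape(-1, 1)

model = tf.keras.Sequential([tf.keras.Input(shape=(5,)), tf.keras.layers.Dense(1)])
# Compiled without metrics, so train_on_batch() returns a single scalar loss.
model.compile(optimizer="adam", loss="mse")

batch_size = 32
for epoch in range(3):
    # The custom iterator: shuffle the sample indices, then walk through them in batches.
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        # One forward pass, backpropagation, and weight update on this batch only.
        loss = model.train_on_batch(x[idx], y[idx])
    print(f"epoch {epoch + 1}: loss on last batch = {float(loss):.4f}")
```

Anything that yields batches, such as an augmentation generator, could feed this loop in place of the index slicing.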

fit() vs fit_generator() vs train_on_batch()

  • fit() is preferable when your training data is small or medium-sized and can be loaded into memory at once. With very large datasets, you will run into memory issues.
  • fit_generator() is useful when you have a large dataset that cannot be loaded into RAM and you want to use a generator to pass the data. It can also be used for data augmentation with ImageDataGenerator.
  • From TensorFlow v2.1 onwards, however, fit_generator() has been deprecated and its functionality has been merged into the fit() function itself.
  • train_on_batch() is a lower-level option, useful for advanced users who want to write their own custom iterator or training loop to pass the data for training. However, most needs can be met by either fit_generator() or fit() (in TensorFlow v2.1+).

Conclusion

In this Keras tutorial, we talked about the fit(), fit_generator(), and train_on_batch() functions, which are used for training Keras models. We discussed how they work and in which scenarios they are used, and finally compared them.


 

  • Palash Sharma

    I am Palash Sharma, an undergraduate student who loves to explore and gain in-depth knowledge in fields like Artificial Intelligence and Machine Learning. I am captivated by the wonders these fields have produced with their novel implementations. With this, I have a desire to share my knowledge with others to the best of my capacity.
