Saturday, November 16, 2024

keras.fit() and keras.fit_generator()

keras.fit() and keras.fit_generator() are two methods of a Keras model in Python, and both can be used to train our machine learning and deep learning models. Both functions accomplish the same task, so the main question is when to use which.

Keras.fit()

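Syntax (this signature is from the Keras R interface; the Python model.fit() method takes largely the same arguments, with None/True in place of NULL/TRUE):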

fit(object, x = NULL, y = NULL, batch_size = NULL, epochs = 10,
  verbose = getOption("keras.fit_verbose", default = 1),
  callbacks = NULL, view_metrics = getOption("keras.view_metrics",
  default = "auto"), validation_split = 0, validation_data = NULL,
  shuffle = TRUE, class_weight = NULL, sample_weight = NULL,
  initial_epoch = 0, steps_per_epoch = NULL, validation_steps = NULL,
  ...)

Understanding a few important arguments:

-> object : the model to train.
-> x : our training data. Can be a vector, array or matrix.
-> y : our training labels. Can be a vector, array or matrix.
-> batch_size : any integer value or NULL; when left as NULL it defaults
to 32. It specifies the number of samples per gradient update.
-> epochs : an integer, the number of epochs we want to train our model for.
-> verbose : specifies the verbosity mode (0 = silent, 1 = progress bar, 2 = one
line per epoch).
-> shuffle : whether we want to shuffle our training data before each epoch.
-> steps_per_epoch : the total number of steps (batches) to run before
one epoch is declared finished and the next epoch starts. Its default value is NULL.

How to use Keras fit:

model.fit(Xtrain, Ytrain, batch_size = 32, epochs = 100)

Here we pass the training data (Xtrain) and the training labels (Ytrain) to the model, and Keras trains it for 100 epochs with a batch size of 32.
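To make the relationship between epochs and batch size concrete, the sketch below counts how many weight updates such a call performs. The training-set size of 1000 samples is a hypothetical value chosen for illustration, and the helper name updates_per_epoch is not part of Keras.

```python
import math


def updates_per_epoch(num_samples, batch_size=32):
    """Number of gradient updates in one epoch; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)


num_samples = 1000  # hypothetical training-set size
epochs = 100

per_epoch = updates_per_epoch(num_samples, batch_size=32)
total_updates = per_epoch * epochs

print(per_epoch)      # 32 (31 full batches plus one batch of 8 samples)
print(total_updates)  # 3200
```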

When we call the .fit() function, it makes the following assumptions:

  • The entire training set fits into the Random Access Memory (RAM) of the computer.
  • Calling the model.fit method a second time does not reinitialize our already trained weights, which means we can make consecutive calls to fit and training continues from where it left off.
  • There is no need for Keras generators (i.e. no data augmentation).
  • The raw data itself is used to train our network, and that raw data fits entirely into memory.
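A rough way to check the first assumption is to estimate the size of the dataset once it is loaded as an array. The sketch below assumes image data stored as 4-byte floats; the dataset dimensions, the 8 GiB of RAM and the safety factor are hypothetical values, not recommendations.

```python
def dataset_size_bytes(num_images, height, width, channels, bytes_per_value=4):
    """Approximate in-memory size of an image dataset stored as one dense array."""
    return num_images * height * width * channels * bytes_per_value


def fits_in_memory(size_bytes, available_bytes, safety_factor=0.5):
    """Leave headroom for the model, activations and the rest of the system."""
    return size_bytes <= available_bytes * safety_factor


ram = 8 * 1024 ** 3  # hypothetical machine with 8 GiB of RAM

small = dataset_size_bytes(50_000, 32, 32, 3)       # small images
large = dataset_size_bytes(1_000_000, 224, 224, 3)  # large images

print(small, fits_in_memory(small, ram))  # 614400000 True
print(large, fits_in_memory(large, ram))  # 602112000000 False
```

In the first case .fit() on in-memory arrays is reasonable; in the second the data has to be streamed in batches.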


The Keras.fit_generator():

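Syntax (this signature is from the Keras R interface; the Python model.fit_generator() method takes largely the same arguments, with None in place of NULL):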

fit_generator(object, generator, steps_per_epoch, epochs = 1,
  verbose = getOption("keras.fit_verbose", default = 1),
  callbacks = NULL, view_metrics = getOption("keras.view_metrics",
  default = "auto"), validation_data = NULL, validation_steps = NULL,
  class_weight = NULL, max_queue_size = 10, workers = 1,
  initial_epoch = 0)

Understanding a few important arguments:

-> object : the Keras model object.
-> generator : a generator whose output must be a tuple (a list in R) of the form:
                      - (inputs, targets)
                      - (inputs, targets, sample_weights)
A single output of the generator makes a single batch, so all arrays in it
must have a length equal to the batch size. The generator is expected
to loop over its data indefinitely; it should never return or exit.
-> steps_per_epoch : the total number of steps (batches) to draw from the generator
before one epoch is declared finished and the next epoch starts. We can calculate the value
of steps_per_epoch as the total number of samples in the dataset divided by the batch size.
-> epochs : an integer, the number of epochs we want to train our model for.
-> verbose : specifies the verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
-> callbacks : a list of callback functions applied during the training of our model.
-> validation_data : can be either:
                      - an (inputs, targets) list
                      - an (inputs, targets, sample_weights) list
                      - a generator
It is used to evaluate the loss and metrics of the model at the end of each epoch.
-> validation_steps : used only when validation_data is a generator.
It specifies the total number of steps (batches) to draw from the validation generator
at the end of every epoch, and its value is calculated as the total number of validation
data points in the dataset divided by the validation batch size.
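The generator contract described above (yield one full batch at a time, never stop) can be sketched in plain Python. This is a minimal illustration, not the Keras implementation: the function name batch_generator and the toy data are hypothetical, and real Keras training would need NumPy arrays rather than Python lists.

```python
import random


def batch_generator(inputs, targets, batch_size=32, shuffle=True):
    """Yield (inputs, targets) batches forever, reshuffling on every pass."""
    indices = list(range(len(inputs)))
    while True:  # never return: the caller stops after steps_per_epoch batches
        if shuffle:
            random.shuffle(indices)
        # Skip the incomplete final batch so every batch has batch_size items.
        for start in range(0, len(indices) - batch_size + 1, batch_size):
            chosen = indices[start:start + batch_size]
            yield [inputs[i] for i in chosen], [targets[i] for i in chosen]


# Hypothetical toy data: 100 samples with one feature each.
train_x = [[float(i)] for i in range(100)]
train_y = [i % 2 for i in range(100)]

batch_size = 32
steps_per_epoch = len(train_x) // batch_size  # 100 // 32 = 3

gen = batch_generator(train_x, train_y, batch_size)
batch_x, batch_y = next(gen)
print(steps_per_epoch, len(batch_x), len(batch_y))  # 3 32 32
```

Because the partial batch is dropped, steps_per_epoch computed with integer division matches the number of batches the generator produces per pass over the data.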

How to use Keras fit_generator:

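# tf.keras import path (older standalone Keras uses keras.preprocessing.image)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# performing data augmentation with an image data generator
dataAugmentaion = ImageDataGenerator(rotation_range = 30, zoom_range = 0.20,
    fill_mode = "nearest", shear_range = 0.20, horizontal_flip = True,
    width_shift_range = 0.1, height_shift_range = 0.1)

# training the model on augmented batches of 32 images
# (in TensorFlow 2.1 and later, fit_generator is deprecated and model.fit
# accepts the same generator)
model.fit_generator(dataAugmentaion.flow(trainX, trainY, batch_size = 32),
    validation_data = (testX, testY), steps_per_epoch = len(trainX) // 32,
    epochs = 10)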

Here we are training our network for 10 epochs with a batch size of 32, which is set explicitly in flow() and is also its default.

For small and less complex datasets it is recommended to use the keras.fit function. Real-world datasets are rarely that simple: they are often too large to fit into the computer's memory.
Such datasets are more challenging to work with, and an important step in handling them is data augmentation, which helps to avoid overfitting and improves the ability of our model to generalize.

Data augmentation is a method of artificially creating new training data from the existing training dataset, in order to improve the performance of deep learning neural networks without collecting more data. It is a form of regularization that helps our model generalize better.
Here we have used a Keras ImageDataGenerator object to apply data augmentation by randomly translating, resizing, rotating, etc. the images. Each new batch of our data is randomly adjusted according to the parameters supplied to ImageDataGenerator.
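The idea of randomly adjusting each image can be sketched without Keras. The example below is a simplified stand-in for what ImageDataGenerator does: it only applies a random horizontal flip and a small horizontal shift to a tiny image stored as a list of rows, and it pads with a constant value rather than the "nearest" fill mode used above. All function names are hypothetical.

```python
import random


def horizontal_flip(image):
    """Mirror a 2-D image (a list of rows) left to right."""
    return [row[::-1] for row in image]


def shift_right(image, pixels, fill=0):
    """Shift every row right by `pixels`, padding the left edge with `fill`."""
    return [[fill] * pixels + row[:len(row) - pixels] for row in image]


def augment(image, rng, max_shift=1):
    """Return a randomly flipped and shifted copy of the image."""
    out = image
    if rng.random() < 0.5:
        out = horizontal_flip(out)
    return shift_right(out, rng.randint(0, max_shift))


image = [[1, 2, 3],
         [4, 5, 6]]

print(horizontal_flip(image))  # [[3, 2, 1], [6, 5, 4]]
print(shift_right(image, 1))   # [[0, 1, 2], [0, 4, 5]]

rng = random.Random(0)  # seeded only to make runs repeatable
augmented = augment(image, rng)  # same shape as the input, randomly altered
```

Each call to augment can return a different variant of the same image, which is why the network rarely sees exactly the same input twice.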

When we call the .fit_generator() function, the following steps take place:

  • Keras first calls the generator function (dataAugmentaion).
  • The generator function (dataAugmentaion) yields a batch of 32 samples to our .fit_generator() function.
  • Our .fit_generator() function accepts the batch, performs backpropagation on it, and then updates the weights in our model.
  • The process is repeated for the specified number of epochs (10 in our case).
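The steps above can be sketched as a plain training loop. This is an illustration under simplifying assumptions, not how Keras is implemented: the model is a single weight w fitted to y = 2x, plain gradient descent on the mean squared error stands in for backpropagation, and the generator toy_batches is hypothetical.

```python
def toy_batches():
    """Hypothetical generator: yields the same (xs, ys) batch for y = 2x forever."""
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0 * x for x in xs]
    while True:
        yield xs, ys


def train(generator, steps_per_epoch, epochs, learning_rate=0.01):
    """Draw batches from the generator and update one weight after each batch."""
    w = 0.0
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            xs, ys = next(generator)  # 1. ask the generator for a batch
            # 2. gradient of the mean squared error with respect to w
            grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
            w -= learning_rate * grad  # 3. update the weight
    return w


w = train(toy_batches(), steps_per_epoch=5, epochs=10)
print(round(w, 2))  # 2.0
```

The generator never ends on its own; the loop decides when an epoch is over by counting steps_per_epoch batches.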

Summary:
So, we have learned the difference between the Keras.fit and Keras.fit_generator functions used to train a deep learning neural network.
.fit is used when the entire training dataset fits into memory and no data augmentation is applied.
.fit_generator is used when the dataset is too large to fit into memory or when data augmentation needs to be applied.
