URL: https://www.analyticsvidhya.com/blog/2021/08/quick-start-with-tensorflow-callbacks/

Quick Start with Tensorflow Callbacks

Ashish Last Updated : 31 Aug, 2021

This article was published as a part of the Data Science Blogathon

What are Tensorflow Callbacks?

TensorFlow callbacks are functions or blocks of code that are executed at specific points while training a deep learning model.

We are all familiar with the training process of a deep learning model. As models have become more complex and resource-intensive, training times have also increased significantly, so it is common for a model to take many hours to train. In the usual workflow, we fix all the options and parameters (learning rate, optimizer, loss, etc.) before training and then start the run. Once training has started, there is no simple way to pause it and change a parameter, even when the model has been training for several hours and we want to tweak something in the later stages. This is where TensorFlow callbacks come to the rescue.

How to use Callbacks

1. First define the callbacks
2. Pass the callbacks when calling the model.fit()

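from tensorflow.keras.callbacks import ReduceLROnPlateau, TerminateOnNaN

# Stop training if a NaN loss is encountered
NanStop = TerminateOnNaN()
# Decrease lr by 10% when val_accuracy stops improving
LrValAccuracy = ReduceLROnPlateau(monitor='val_accuracy', patience=1, factor=0.9, mode='max', verbose=0)
# model, X_train, y_train, X_test and y_test are assumed to be defined already
model.fit(X_train, y_train,
          epochs=10,
          validation_data=(X_test, y_test),
          callbacks=[NanStop, LrValAccuracy])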

Let us have a look at some of the most useful callbacks

EarlyStopping

When we are training our models, we usually watch the metrics to monitor how well the model is performing. If the training metrics keep improving while the validation metrics stall or get worse, the model is probably overfitting; if both stay poor, the model is underfitting.

When the monitored metric stops improving, we can stop the training to prevent overfitting. The EarlyStopping callback allows us to do exactly this.

early_stop_cb = tf.keras.callbacks.EarlyStopping(
 monitor='val_loss', min_delta=0, patience=0, verbose=0,
 mode='auto'
)
  • monitor: the metric you want to monitor while training
  • min_delta: the minimum change in the monitored metric that counts as an improvement
  • patience: the number of epochs to wait for the metric to improve; if it does not improve in that time, training stops
  • verbose: 0: print nothing; 1: print a message when the callback stops the training
  • mode:
      • “auto” – infer the direction automatically from the name of the monitored metric
      • “min” – stop training when the metric has stopped decreasing
      • “max” – stop training when the metric has stopped increasing

LambdaCallback

This callback is used to call certain lambda functions at specific times during the training process.
tf.keras.callbacks.LambdaCallback(
 on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None,
 on_train_begin=None, on_train_end=None, **kwargs
)

Here we can pass any lambda function we need to execute at the specified time. The epoch functions receive (epoch, logs), the batch functions receive (batch, logs), and the train functions receive (logs). Let’s see what the arguments mean:

  • on_epoch_begin: calls the function at the beginning of each epoch
  • on_epoch_end: calls the function at the end of each epoch
  • on_batch_begin: calls the function at the beginning of each batch
  • on_batch_end: calls the function at the end of each batch
  • on_train_begin: calls the function when the model starts training
  • on_train_end: calls the function when the model training is completed

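from tensorflow.keras.callbacks import LambdaCallback

# Print the batch index at the start and at the end of every batch
print_batch_callback = LambdaCallback(
 on_batch_begin=lambda bat, log: print(bat),
 on_batch_end=lambda bat, log: print(bat)
)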

LearningRateScheduler

One of the most common tasks during the training process is to change the learning rate. Usually, as the model approaches a minimum of the loss (the best fit), we gradually decrease the learning rate to get better convergence.

Let’s see a simple example where we want to reduce our learning rate by 5% for every 3rd epoch. Here we need to pass in a function to the schedule argument which specifies the logic for change in learning rate.

def schedule(epoch, lr):
    # epoch is 0-indexed, so this fires at epochs 0, 3, 6, ...
    if epoch % 3 == 0:
        lr = lr - (lr * .05)
    return lr

# Decrease lr by 5% for every 3rd epoch
LrScheduler = tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)
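
To see what this schedule does over a few epochs without training a model, the function can be called by hand in a loop. This is only a sketch: the starting learning rate of 0.01 is a hypothetical value, and the loop imitates the way the callback passes the current epoch and learning rate to the function.

```python
def schedule(epoch, lr):
    # Same rule as in the article: cut the rate by 5% when epoch is a multiple of 3
    if epoch % 3 == 0:
        lr = lr - (lr * .05)
    return lr


lr = 0.01  # hypothetical starting learning rate
rates = []
for epoch in range(7):
    lr = schedule(epoch, lr)  # the new rate becomes the input for the next epoch
    rates.append(lr)

for epoch, rate in enumerate(rates):
    print(epoch, round(rate, 6))
```

Because epochs are 0-indexed, the rate already drops at epoch 0 and then again at epochs 3 and 6. If the first reduction should happen later, the condition can be changed to something like epoch > 0 and epoch % 3 == 0.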

ModelCheckpoint

We use this callback in order to save our Model at different frequencies. This allows us to save weights at intermediate steps so that if needed we can load weights later.

tf.keras.callbacks.ModelCheckpoint(
 filepath, monitor='val_loss', verbose=0, save_best_only=False,
 save_weights_only=False, mode='auto', save_freq='epoch'
)

filepath: the location where the model is saved
monitor: metric to be monitored
save_best_only: True: save only when the monitored metric improves, so that the best model is kept; False: save at every save_freq interval regardless of the metric
mode: min, max, or auto
save_weights_only: True: save only the model weights; False: save the full model (architecture and weights)

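For example, here is how to save the weights of the model with the best validation accuracy:

# Note: newer Keras versions may require the path to end in ".weights.h5"
# when save_weights_only=True
filePath = "models/Model1_weights.{epoch:02d}.hdf5"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
 filepath=filePath,
 save_weights_only=True,
 save_best_only=True,  # without this, monitor and mode have no effect
 monitor='val_accuracy',
 mode='max')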

Here we specify the file path using some template strings. {epoch:02d} is substituted by the epoch number when saving the model
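
The placeholder follows Python's ordinary string formatting rules, so its effect can be previewed with str.format. The sketch below only illustrates the naming pattern. The second template, with a val_accuracy placeholder, and the value 0.9134 are made-up examples of using a logged metric in the file name.

```python
file_path = "models/Model1_weights.{epoch:02d}.hdf5"

# Preview the file names produced for a few epoch numbers
names = [file_path.format(epoch=e) for e in (1, 2, 10)]
print(names)

# A logged metric can be embedded in the name in the same way
# (val_accuracy=0.9134 is a made-up value for illustration)
metric_path = "models/Model1.{epoch:02d}-{val_accuracy:.2f}.hdf5"
with_metric = metric_path.format(epoch=3, val_accuracy=0.9134)
print(with_metric)
```

Zero-padding the epoch number with 02d keeps the files in order when they are listed alphabetically.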

ReduceLROnPlateau

This callback is used to reduce the learning rate when the monitored metric has stopped improving and has reached a plateau.

tf.keras.callbacks.ReduceLROnPlateau(
 monitor='val_loss', factor=0.1, patience=10, verbose=0,
 mode='auto', min_delta=0.0001, cooldown=0, min_lr=0, **kwargs
)

factor: the factor by which the LR is reduced. new_learning_rate = old_learning_rate * factor
patience: number of epochs with no improvement after which the LR is reduced
min_delta: minimum change needed to be considered an improvement
cooldown: number of epochs to wait after an LR reduction before normal monitoring resumes
min_lr: the lower bound below which the learning rate can't go
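
The way factor, patience, cooldown and min_lr interact is easier to follow in a small simulation. The following is a simplified pure-Python sketch for a metric that should decrease (mode 'min'). It is not the Keras source code, and the function name and loss values are hypothetical.

```python
def simulate_reduce_on_plateau(losses, lr, factor=0.1, patience=10,
                               min_delta=1e-4, cooldown=0, min_lr=0.0):
    """Return the learning rate in effect after each epoch (simplified sketch)."""
    best = float("inf")
    wait = 0           # epochs without improvement
    cooldown_left = 0  # epochs left before monitoring resumes
    history = []
    for current in losses:
        if cooldown_left > 0:
            cooldown_left -= 1
            wait = 0
        if current < best - min_delta:
            best = current
            wait = 0
        elif cooldown_left == 0:
            wait += 1
            if wait >= patience:
                lr = max(lr * factor, min_lr)  # never go below min_lr
                cooldown_left = cooldown
                wait = 0
        history.append(lr)
    return history


# Hypothetical validation losses that plateau at 0.8
val_loss = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8]
lrs = simulate_reduce_on_plateau(val_loss, lr=0.1, factor=0.5,
                                 patience=2, min_lr=0.03)
print(lrs)
```

In this run the rate is halved after two epochs without improvement, and the second reduction is clipped to min_lr instead of dropping to 0.025.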

TerminateOnNaN

This callback stops the training process when any loss becomes NaN
tf.keras.callbacks.TerminateOnNaN()

TensorBoard

TensorBoard allows us to display information about the training process, such as metrics, training graphs, activation histograms, and gradient distributions. To use TensorBoard, we first need to set up a log_dir where the TensorBoard files are saved.

log_dir="logs"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=True)
  • log_dir: directory to which the log files are saved
  • histogram_freq: frequency (in epochs) at which histograms are computed; 0 disables them
  • write_graph: whether to visualize the model graph in TensorBoard

Image 1 (link below)

Write your own Callbacks

Apart from the inbuilt callbacks, we can define and use our own callbacks for different purposes. For example, let us say we want to define our own metric which gets calculated at the end of each epoch.

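import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score, roc_auc_score

# Monitor micro-F1 and AUC score on the validation data
class Metrics_Callback(tf.keras.callbacks.Callback):
    def __init__(self, x_val, y_val):
        super().__init__()
        self.x_val = x_val
        self.y_val = y_val  # assumed to be one-hot encoded with two classes

    def on_train_begin(self, logs=None):
        self.history = {"auc_score": [], "micro_f1": []}

    def on_epoch_end(self, epoch, logs=None):
        # predict_proba and predict_classes were removed from recent TensorFlow
        # releases, so we use predict and take the argmax instead
        y_prob = self.model.predict(self.x_val, verbose=0)
        auc_score = roc_auc_score(self.y_val, y_prob)
        y_true = [0 if x[0] == 1.0 else 1 for x in self.y_val]
        f1_s = f1_score(y_true, np.argmax(y_prob, axis=1), average='micro')
        self.history["auc_score"].append(auc_score)
        self.history["micro_f1"].append(f1_s)

Metrics = Metrics_Callback(X_test, y_test)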

Here we want to calculate the F1 score and AUC score at the end of each epoch. In the __init__ method we store the data needed to calculate the scores. Then, at the end of each epoch, we calculate the metrics in the on_epoch_end method. We can override the following methods to execute code at different times:

on_epoch_begin: called at the beginning of each epoch.
on_epoch_end: called at the end of each epoch.
on_batch_begin: called at the beginning of each batch.
on_batch_end: called at the end of each batch.
on_train_begin: called when the model starts training.
on_train_end: called when the model training is completed.
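
The order in which these methods fire can be illustrated with a toy loop. This sketch does not use TensorFlow at all: RecordingCallback and toy_fit are hypothetical stand-ins that only imitate the nesting of training, epochs and batches.

```python
class RecordingCallback:
    """Stand-in for a Keras callback that records every hook it receives."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("train_begin")

    def on_train_end(self, logs=None):
        self.calls.append("train_end")

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(f"epoch_begin {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"epoch_end {epoch}")

    def on_batch_begin(self, batch, logs=None):
        self.calls.append(f"batch_begin {batch}")

    def on_batch_end(self, batch, logs=None):
        self.calls.append(f"batch_end {batch}")


def toy_fit(callback, epochs, batches_per_epoch):
    """Imitate the structure of a training loop without training anything."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            callback.on_batch_begin(batch)
            # a real training step on one batch would run here
            callback.on_batch_end(batch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


recorder = RecordingCallback()
toy_fit(recorder, epochs=1, batches_per_epoch=2)
print(recorder.calls)
```

The batch hooks run many times per epoch, so expensive work such as computing metrics on a full validation set is usually better placed in on_epoch_end.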

Conclusion

These were a few of the most commonly used callbacks. The official TensorFlow documentation (https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback) gives detailed information about the other callbacks and their use cases.

Image Sources

  1. Image 1 – https://www.tensorflow.org/tensorboard/get_started
The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.

A Data Scientist currently exploring the fascinating world of Data Science. Well versed in Machine Learning and Deep Learning, with an inclination towards NLP and Computer Vision applications. Apart from data science, my prior experience includes working as a Python Developer for a brief period.


Responses From Readers

AI Tools

Wow, this article on Tensorflow Callbacks from Ashish Salaskar, as part of the Data Science Blogathon, was an absolute revelation for me! I was aware of the training process involved with Deep Learning models, but the concept of callbacks really adds another layer of complexity and control that wasn't previously clear to me. The idea that callbacks such as 'NanStop' or 'EarlyStopping' allow us to modify the training process dynamically to avoid overfitting or stop training at a certain metric level is fascinating. It's clear that there is a level of sophistication here that heightens the control and influence we have on the model's learning and development. The use of lambdas within the 'LambdaCallback' was especially intriguing, offering a means to engage with certain functions during specific moments of training. Could you clarify more on how and when you'd use this specific callback? Also, the 'ModelCheckpoint' function appears to be a really useful tool, enabling intermediate weights and model architecture to be stored. I wonder, is there a best practice regarding the frequency of checkpoints? I'm extremely appreciative of the section on how to create custom callbacks, such as expanding into metrics like F1 Scores and AUC scores. It just underlines the flexibility and freedom that tensorflow callbacks afford! So much appreciation for sharing this wonderful article! I'm eager to explore more about each callback and how it can improve and optimize my own Deep Learning models. I would certainly be diving deeper into the TensorFlow documentation. Thank you for inspiring this exciting new learning journey!
