URL: https://www.analyticsvidhya.com/blog/2018/03/essentials-of-deep-learning-visualizing-convolutional-neural-networks/

Essentials of Deep Learning: Visualizing Convolutional Neural Networks in Python

Faizan Shaikh Last Updated : 15 May, 2020
9 min read

Introduction

One of the most debated topics in deep learning is how to interpret and understand a trained model, particularly in the context of high-risk industries like healthcare. The term "black box" has often been associated with deep learning algorithms. How can we trust the results of a model if we can't explain how it works? It's a legitimate question.

Take the example of a deep learning model trained for detecting cancerous tumours. The model tells you that it is 99% sure that it has detected cancer, but it does not tell you why or how it made that decision.

Did it find an important clue in the MRI scan? Or was it just a smudge on the scan that was incorrectly detected as a tumour? This is a matter of life and death for the patient and doctors cannot afford to be wrong.

In this article, we will explore how to visualize a convolutional neural network (CNN), the deep learning architecture used in most state-of-the-art image-based applications. We will look at why visualizing a CNN model matters and at the methods available for doing so. We will also work through a use case that will help you understand the concepts better.

Note: This article assumes that you know the basics of Deep Learning and have previously worked on image processing problems using CNN. Also, we will be using Keras as our deep learning library. If you want to brush up on the concepts, you can go through these articles first:

You can also enroll in this free course on CNNs to learn about them in a structured manner: Convolutional Neural Networks (CNN) from Scratch

Let's get on with it!

Table of Contents

  • Importance of Visualizing a CNN model
  • Methods of Visualization
    1. Preliminary Methods
      • Plot Model Architecture
      • Visualize Filters
    2. Activation based Methods
      • Maximal Activation
      • Image Occlusion
    3. Gradient based Methods
      • Saliency Map
      • Gradient based Class Activation Map

Importance of Visualizing a CNN model

As we have seen in the cancerous tumour example above, it is crucial that we know what our model is doing and how it arrives at its predictions. Typically, the reasons listed below are the most important points for a deep learning practitioner to remember:

  1. Understanding how the model works
  2. Assistance in Hyperparameter tuning
  3. Finding out the failures of the model and getting an intuition of why they fail
  4. Explaining the decisions to a consumer / end-user or a business executive

Let us look at an example where understanding what a neural network had actually learned exposed its flaws and showed how to improve its performance (the example below is sourced from: http://intelligence.org/files/AIPosNegFactor.pdf).

Once upon a time, the US Army wanted to use neural networks to automatically detect camouflaged enemy tanks. The researchers trained a neural net on 50 photos of camouflaged tanks in trees, and 50 photos of trees without tanks. Using standard techniques for supervised learning, the researchers trained the neural network to a weighting that correctly loaded the training set: output "yes" for the 50 photos of camouflaged tanks, and output "no" for the 50 photos of forest.

This did not ensure, or even imply, that new examples would be classified correctly. The neural network might have "learned" 100 special cases that would not generalize to any new problem. Wisely, the researchers had originally taken 200 photos, 100 photos of tanks and 100 photos of trees. They had used only 50 of each for the training set. The researchers ran the neural network on the remaining 100 photos, and without further training the neural network classified all remaining photos correctly. Success confirmed! The researchers handed the finished work to the Pentagon, which soon handed it back, complaining that in their own tests the neural network did no better than chance at discriminating photos.

It turned out that in the researchers' dataset, photos of camouflaged tanks had been taken on cloudy days, while photos of plain forest had been taken on sunny days. The neural network had learned to distinguish cloudy days from sunny days, instead of distinguishing camouflaged tanks from an empty forest.

Methods of Visualizing a CNN model

Broadly, the methods of visualizing a CNN model can be categorized into three groups based on their internal workings:

  • Preliminary methods: Simple methods which show us the overall structure of a trained model
  • Activation based methods: In these methods, we decipher the activations of the individual neurons or a group of neurons to get an intuition of what they are doing
  • Gradient based methods: These methods tend to manipulate the gradients that are formed from a forward and backward pass while training a model

We will look at each of them in detail in the sections below. Here we will be using keras as our library for building deep learning models and keras-vis for visualizing them. Make sure you have installed these in your system before going ahead.

NOTE: This article uses the dataset given in the "Identify the Digits" competition. To run the code mentioned below, you would have to download it to your system. Also, please perform the steps provided in this page before starting with the implementation below.

1. Preliminary Methods

1.1 Plotting model architecture

The simplest thing you can do is to print/plot the model. Here, you can also print the shapes of individual layers of neural network and the parameters in each layer.

In keras, you can implement it as below:

model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 24, 24, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 12, 12, 64)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 9216)              0
_________________________________________________________________
dense_1 (Dense)              (None, 128)               1179776
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0
_________________________________________________________________
preds (Dense)                (None, 10)                1290
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
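
The parameter counts in this summary can be checked by hand. The sketch below assumes 3x3 kernels, unpadded stride-1 convolutions, a 28x28x1 input and 2x2 max pooling. None of these are stated in the article; they are inferred from the output shapes above, and the helper names are purely illustrative.

```python
def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter for a square-kernel Conv2D layer."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights plus one bias per unit for a Dense layer."""
    return inputs * units + units


def valid_conv_size(size, kernel):
    """Spatial size after a stride-1 convolution without padding."""
    return size - kernel + 1


# Assumed hyperparameters (inferred, not stated): 28x28x1 input, 3x3 kernels.
side = valid_conv_size(28, 3)      # conv2d_1 output side
side = valid_conv_size(side, 3)    # conv2d_2 output side
side = side // 2                   # max_pooling2d_1, assumed 2x2
flat = side * side * 64            # flatten_1 length

params = {
    "conv2d_1": conv_params(3, 1, 32),
    "conv2d_2": conv_params(3, 32, 64),
    "dense_1": dense_params(flat, 128),
    "preds": dense_params(128, 10),
}
total = sum(params.values())

for name, count in params.items():
    print(name, count)
print("total", total)
```

Almost all of the parameters sit in dense_1, which is worth keeping in mind when reading the filter visualizations that follow: the convolutional layers are comparatively small.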

For a more creative and expressive view, you can draw a diagram of the architecture (hint: take a look at the plot_model function in the keras.utils.vis_utils module).

1.2 Visualize filters

Another way is to plot the filters of a trained model, so that we can understand the behaviour of those filters. For example, the first filter of the first layer of the above model looks like:

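from matplotlib import pyplot as plt

# get_weights()[0] is the kernel tensor (height, width, in_channels, filters); plot filter 0
top_layer = model.layers[0]
plt.imshow(top_layer.get_weights()[0][:, :, :, 0].squeeze(), cmap='gray')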

Generally, we see that the low level filters work as edge detectors, and as we go higher, they tend to capture high level concepts like objects and faces.

Image source: http://web.eecs.umich.edu/~honglak/cacm2011-researchHighlights-convDBN.pdf
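
To make the edge-detector intuition concrete, here is a minimal NumPy sketch that slides a hand-written vertical-edge kernel over a tiny synthetic image. The kernel, the image and the helper conv2d_valid are illustrative assumptions, not filters taken from the trained model above.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide the kernel over the image without padding (cross-correlation, as in CNN layers)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.float32)
    for y in range(out_h):
        for x in range(out_w):
            out[y, x] = np.sum(image[y:y + kh, x:x + kw] * kernel)
    return out


# 6x6 image: dark left half, bright right half, so there is one vertical edge.
image = np.zeros((6, 6), dtype=np.float32)
image[:, 3:] = 1.0

# Hand-written kernel that responds to dark-to-bright transitions from left to right.
vertical_edge_kernel = np.array([[-1, 0, 1],
                                 [-1, 0, 1],
                                 [-1, 0, 1]], dtype=np.float32)

response = conv2d_valid(image, vertical_edge_kernel)
print(response)
```

The response is non-zero only in the columns whose 3x3 window straddles the edge, and zero over the flat regions on either side. Learned first-layer filters often behave in a similar way, although their weights are found by training rather than written by hand.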

2. Activation Based Methods

2.1 Maximal Activations

To see what our neural network is doing, we can look for the input patterns that activate a particular filter most strongly. For example, there could be a face filter that activates when a face is present in the image. The code below uses keras-vis to generate an input image that maximizes the activation of a chosen output node.

from vis.visualization import visualize_activation
from vis.utils import utils
from keras import activations

from matplotlib import pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (18, 6)

# Utility to search for layer index by name.
# Alternatively we can specify this as -1 since it corresponds to the last layer.
layer_idx = utils.find_layer_idx(model, 'preds')

# Swap softmax with linear
model.layers[layer_idx].activation = activations.linear
model = utils.apply_modifications(model)

# This is the output node we want to maximize.
filter_idx = 0
img = visualize_activation(model, layer_idx, filter_indices=filter_idx)
plt.imshow(img[..., 0])

We can apply this idea to all the classes and see what each of them looks like.

PS: Run the script below to check it out.

import numpy as np

for output_idx in np.arange(10):
    # Generate the input that maximizes each output node; pixel values are kept in [0, 1].
    img = visualize_activation(model, layer_idx, filter_indices=output_idx, input_range=(0., 1.))
    plt.figure()
    plt.title('Networks perception of {}'.format(output_idx))
    plt.imshow(img[..., 0])
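
The idea behind maximal activation can also be sketched without any deep learning library: keep the weights fixed and run gradient ascent on the input. The example below is a toy under stated assumptions, not how keras-vis implements visualize_activation. The "unit" is a single linear filter, the L2 penalty and step size are arbitrary choices, and all names are hypothetical.

```python
import numpy as np


def maximize_activation(weights, l2=1.0, lr=0.1, steps=200, seed=0):
    """Gradient ascent on the input of a toy linear unit.

    Objective: activation(x) - 0.5 * l2 * ||x||^2,
    where activation(x) = sum(weights * x).
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=0.01, size=weights.shape)  # start from faint noise
    for _ in range(steps):
        grad = weights - l2 * x   # gradient of the objective with respect to the input
        x = x + lr * grad         # step uphill: the input changes, the weights do not
    return x


# Hypothetical 3x3 "filter" that prefers a bright centre column.
weights = np.array([[-1.0, 2.0, -1.0],
                    [-1.0, 2.0, -1.0],
                    [-1.0, 2.0, -1.0]])

preferred_input = maximize_activation(weights)
print(np.round(preferred_input, 2))
```

For this linear unit the optimized input simply converges to the filter pattern itself. In a real network the objective is non-linear, so the resulting image is harder to predict, which is what makes the visualization informative.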

2.2 Image Occlusion

In an image classification problem, a natural question is whether the model is truly identifying the location of the object in the image, or just using the surrounding context. Gradient based methods, covered later in this article, give one view of this. Occlusion based methods attempt to answer the question by systematically occluding different portions of the input image with a grey square and monitoring the output of the classifier. If the probability of the correct class drops significantly when the object is occluded, as in the examples shown, the model is localizing the object within the scene.

To understand this concept, let us take a random image from our dataset and try to plot a heatmap of the image. This will give us an intuition of which parts of the image are important for that model in order to make a clear distinction of the actual class.

import numpy as np

def iter_occlusion(image, size=8):
    # taken from https://www.kaggle.com/blargl/simple-occlusion-and-saliency-maps

    # Grey outer patch (5x the step size) and inner patch, both filled with 0.5
    occlusion = np.full((size * 5, size * 5, 1), [0.5], np.float32)
    occlusion_center = np.full((size, size, 1), [0.5], np.float32)
    occlusion_padding = size * 2

    # Pad the image so that patches near the border stay inside the array
    image_padded = np.pad(image, (
        (occlusion_padding, occlusion_padding),
        (occlusion_padding, occlusion_padding),
        (0, 0)
    ), 'constant', constant_values=0.0)

    for y in range(occlusion_padding, image.shape[0] + occlusion_padding, size):
        for x in range(occlusion_padding, image.shape[1] + occlusion_padding, size):
            tmp = image_padded.copy()

            tmp[y - occlusion_padding:y + occlusion_center.shape[0] + occlusion_padding,
                x - occlusion_padding:x + occlusion_center.shape[1] + occlusion_padding] \
                = occlusion

            tmp[y:y + occlusion_center.shape[0], x:x + occlusion_center.shape[1]] = occlusion_center

            # Yield the patch origin in original-image coordinates and the cropped, occluded image
            yield x - occlusion_padding, y - occlusion_padding, \
                tmp[occlusion_padding:tmp.shape[0] - occlusion_padding,
                    occlusion_padding:tmp.shape[1] - occlusion_padding]

i = 23 # for example
data = val_x[i]
correct_class = np.argmax(val_y[i])

# input tensor for model.predict
inp = data.reshape(1, 28, 28, 1)

# image data for matplotlib's imshow
img = data.reshape(28, 28)

# occlusion
img_size = img.shape[0]
occlusion_size = 4

print('occluding...')

heatmap = np.zeros((img_size, img_size), np.float32)
class_pixels = np.zeros((img_size, img_size), np.int16)

from collections import defaultdict
counters = defaultdict(int)

for n, (x, y, img_float) in enumerate(iter_occlusion(data, size=occlusion_size)):

    X = img_float.reshape(1, 28, 28, 1)
    out = model.predict(X)
    #print('#{}: {} @ {} (correct class: {})'.format(n, np.argmax(out), np.amax(out), out[0][correct_class]))
    #print('x {} - {} | y {} - {}'.format(x, x + occlusion_size, y, y + occlusion_size))

    heatmap[y:y + occlusion_size, x:x + occlusion_size] = out[0][correct_class]
    class_pixels[y:y + occlusion_size, x:x + occlusion_size] = np.argmax(out)
    counters[np.argmax(out)] += 1
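
If you want to see the mechanics without a trained model, the same loop can be run against a stand-in scoring function. The sketch below is a simplified variant, not the Kaggle code above: it greys out one non-overlapping patch at a time, and toy_score is a hypothetical "classifier" whose confidence is just the mean brightness of the image centre.

```python
import numpy as np


def occlusion_heatmap(image, score_fn, patch=4, fill=0.5):
    """Grey out one patch at a time and record the score of the occluded image."""
    h, w = image.shape
    heatmap = np.zeros((h, w), dtype=np.float32)
    for y in range(0, h, patch):
        for x in range(0, w, patch):
            occluded = image.copy()
            occluded[y:y + patch, x:x + patch] = fill
            heatmap[y:y + patch, x:x + patch] = score_fn(occluded)
    return heatmap


def toy_score(img):
    """Hypothetical stand-in for model.predict: confidence = mean brightness of the centre."""
    return float(img[4:8, 4:8].mean())


# 12x12 image with a bright 4x4 "object" in the centre.
image = np.zeros((12, 12), dtype=np.float32)
image[4:8, 4:8] = 1.0

heatmap = occlusion_heatmap(image, toy_score, patch=4)
print(heatmap)
```

Low values in the heatmap mark the regions the score depends on: here only the patch covering the centre lowers the score, while occluding the background leaves it unchanged. With a real model you would pass a function that wraps model.predict and returns the probability of the correct class.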

3. Gradient Based Methods

3.1 Saliency Maps

As we saw in the tank example, we need a way to find out which part of the image our model focuses on to make a prediction. For this, we can use saliency maps. Saliency maps were first introduced in the paper: Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps.

The concept of using saliency maps is pretty straightforward: we compute the gradient of the output category with respect to the input image. This should tell us how the output category value changes with respect to a small change in the input image pixels. All the positive values in the gradients tell us that a small change to that pixel will increase the output value. Hence, visualizing these gradients, which are the same shape as the image, should provide some intuition of attention.

Intuitively this method highlights the salient image regions that contribute the most towards the output.

import numpy as np

class_idx = 0
indices = np.where(val_y[:, class_idx] == 1.)[0]

# Pick the first validation image that belongs to this class.
idx = indices[0]

# Lets sanity check the picked image.
from matplotlib import pyplot as plt
%matplotlib inline
plt.rcParams['figure.figsize'] = (18, 6)

plt.imshow(val_x[idx][..., 0])


from vis.visualization import visualize_saliency
from vis.utils import utils
from keras import activations

# Utility to search for layer index by name. 
# Alternatively we can specify this as -1 since it corresponds to the last layer.
layer_idx = utils.find_layer_idx(model, 'preds')

# Swap softmax with linear
model.layers[layer_idx].activation = activations.linear
model = utils.apply_modifications(model)

grads = visualize_saliency(model, layer_idx, filter_indices=class_idx, seed_input=val_x[idx])
# Plot with 'jet' colormap to visualize as a heatmap.
plt.imshow(grads, cmap='jet')


# One figure per digit class: the input image, then vanilla, guided and rectified saliency.
for class_idx in np.arange(10):
    indices = np.where(val_y[:, class_idx] == 1.)[0]
    idx = indices[0]

    f, ax = plt.subplots(1, 4)
    ax[0].imshow(val_x[idx][..., 0])

    for i, modifier in enumerate([None, 'guided', 'relu']):
        grads = visualize_saliency(model, layer_idx, filter_indices=class_idx,
                                   seed_input=val_x[idx], backprop_modifier=modifier)
        if modifier is None:
            modifier = 'vanilla'
        ax[i+1].set_title(modifier)
        ax[i+1].imshow(grads, cmap='jet')
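
To see what "gradient of the output with respect to the input" means in isolation, the sketch below approximates it with finite differences on a toy score function. This is only an illustration under stated assumptions: keras-vis obtains the gradients by backpropagation rather than by perturbing pixels, and toy_score with its weights is a made-up stand-in for a network output.

```python
import numpy as np


def numerical_saliency(score_fn, image, eps=1e-3):
    """Approximate |d score / d pixel| with central finite differences."""
    saliency = np.zeros_like(image, dtype=np.float64)
    for idx in np.ndindex(image.shape):
        bumped_up = image.astype(np.float64)    # astype returns a fresh copy
        bumped_down = image.astype(np.float64)
        bumped_up[idx] += eps
        bumped_down[idx] -= eps
        saliency[idx] = abs(score_fn(bumped_up) - score_fn(bumped_down)) / (2 * eps)
    return saliency


# Hypothetical class score: only the centre 2x2 pixels influence it.
weights = np.zeros((4, 4))
weights[1:3, 1:3] = [[1.0, -2.0], [3.0, 0.5]]


def toy_score(img):
    return float(np.sum(weights * img))


image = np.full((4, 4), 0.5)
saliency = numerical_saliency(toy_score, image)
print(np.round(saliency, 3))
```

Pixels the score ignores get a saliency of zero, and the pixel with the largest absolute weight is the most salient, regardless of sign. Perturbing every pixel costs two forward passes per pixel, which is one reason backpropagation is used in practice.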

3.2 Gradient based Class Activation Maps

Gradient-weighted class activation maps, or Grad-CAM, are another way of visualizing what our model looks at while making predictions. Instead of using gradients of the output with respect to the input image, Grad-CAM works with the output of the penultimate layer, that is, the last convolutional layer before the dense layers, and weights its feature maps by the gradients of the class score. This is done to utilize the spatial information that is stored in that layer.

from vis.visualization import visualize_cam

# One figure per digit class: the input image, then vanilla, guided and rectified Grad-CAM.
for class_idx in np.arange(10):
    indices = np.where(val_y[:, class_idx] == 1.)[0]
    idx = indices[0]

    f, ax = plt.subplots(1, 4)
    ax[0].imshow(val_x[idx][..., 0])

    for i, modifier in enumerate([None, 'guided', 'relu']):
        grads = visualize_cam(model, layer_idx, filter_indices=class_idx,
                              seed_input=val_x[idx], backprop_modifier=modifier)
        if modifier is None:
            modifier = 'vanilla'
        ax[i+1].set_title(modifier)
        ax[i+1].imshow(grads, cmap='jet')
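
The combination step at the heart of Grad-CAM is small enough to write out in NumPy. The sketch below assumes you already have the feature maps of a convolutional layer and the gradients of the class score with respect to them. Both arrays here are made-up values, and the final rescaling to [0, 1] is a display convenience rather than part of the method.

```python
import numpy as np


def grad_cam(feature_maps, gradients):
    """Combine conv feature maps (H, W, K) with class-score gradients of the same shape."""
    # One importance weight per channel: the spatial average of its gradients.
    channel_weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the feature maps over the channel axis gives an (H, W) map.
    cam = np.tensordot(feature_maps, channel_weights, axes=([2], [0]))
    # Keep only the locations that push the class score up.
    cam = np.maximum(cam, 0.0)
    # Rescale to [0, 1] for display.
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam


# Hypothetical values: two 2x2 feature maps and made-up gradients.
feature_maps = np.zeros((2, 2, 2))
feature_maps[..., 0] = [[1.0, 0.0], [0.0, 0.0]]   # channel 0 fires top-left
feature_maps[..., 1] = [[0.0, 0.0], [0.0, 1.0]]   # channel 1 fires bottom-right

gradients = np.zeros((2, 2, 2))
gradients[..., 0] = 2.0     # the class score rises with channel 0
gradients[..., 1] = -1.0    # the class score falls with channel 1

cam = grad_cam(feature_maps, gradients)
print(cam)
```

Only the top-left location survives: channel 1 fires at the bottom-right but counts against the class, so the ReLU removes it. In practice the resulting coarse map is upsampled to the input resolution and overlaid on the image.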

End Notes

In this article, we have covered how to visualize a CNN model and why you should do it, along with an example. Visualization has wide-ranging applications, from helping in medical cases to solving logistical issues for the army.

I hope this will give you an intuition of how to build better models in your own deep learning applications.

If you have any ideas / suggestions regarding the topic, do let me know in the comments below!

Faizan is a Data Science enthusiast and a Deep learning rookie. A recent Comp. Sc. undergrad, he aims to utilize his skills to push the boundaries of AI research.

Responses From Readers

Amazing Article.. Thanks a lot Sir..

Faizan Shaikh

Thanks Aditya

Sunny Toms

Good article, like to communicate with you.

Faizan Shaikh

Thanks sunny! I am fairly active on AV's discussion portal. You can always ask me a question there

Thank you for your great article. Do you know any tools which could visualize 3D CNN model? Many thanks.

Faizan Shaikh

Hey! keras-vis library has support for 3D CNN visualization, but I haven't tried it out. You can check this issue on GitHub
