
URL: https://www.analyticsvidhya.com/blog/2021/06/image-recognition-using-pytorch-lightning/


Image Recognition Using Pytorch Lightning

keegan Last Updated : 20 Jul, 2021

This article was published as a part of the Data Science Blogathon

About

PyTorch Lightning

Beginners often get intimidated by the amount of coding required for deep learning. This is often due to complicated code and poor documentation; for the same reason, even veterans in data science can have trouble understanding other people's code. This is why I recommend PyTorch Lightning, an open-source library built on top of PyTorch. PyTorch Lightning automates a lot of the boilerplate code that comes with deep learning and neural networks so you can focus on model building. It also helps in writing cleaner code that is easily reproducible. For more information, check the official PyTorch Lightning website.

The Data

The dataset we will use here is the Yoga-poses dataset available on Kaggle. It is already structured in a way that makes building the model easier. The dataset has two main folders, “Train” and “Test”, and each contains 5 sub-folders. The sub-folders contain the images, and the class of each image is the name of the sub-folder it sits in. The dataset is small compared to other image datasets, so we will use data augmentation during pre-processing. I’d recommend running this on a remote notebook such as a Kaggle notebook, as training an image recognition model on a local machine can be computationally expensive. Now let’s get coding.
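
Before building anything, it can help to confirm that the folders look the way the description says. The following is a minimal sketch, not part of the original notebook: it builds a small throwaway directory tree that mimics the layout described above (the class names and file names are made up for illustration) and counts the files in each class folder. `count_images_per_class` is a hypothetical helper; on Kaggle you would point it at the real train folder instead.

```python
import os
import tempfile


def count_images_per_class(root):
    """Return {class_name: number_of_files} for each sub-folder of root."""
    counts = {}
    for class_name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, class_name)
        if os.path.isdir(class_dir):
            counts[class_name] = len(os.listdir(class_dir))
    return counts


# Build a tiny fake dataset so the sketch runs anywhere.
# The class names below are placeholders, not taken from the real dataset.
demo_root = tempfile.mkdtemp()
for class_name, n_images in [("pose_a", 3), ("pose_b", 2)]:
    os.makedirs(os.path.join(demo_root, class_name))
    for i in range(n_images):
        open(os.path.join(demo_root, class_name, f"img_{i}.jpg"), "w").close()

demo_counts = count_images_per_class(demo_root)
print(demo_counts)
```

A large imbalance between the counts would be worth knowing about before training.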

The Model

Prerequisites

If you want to follow along in a local environment, you will need to install PyTorch Lightning first. Notebooks running on Kaggle or Colab should already have it installed.

To install it in a local Python environment:

pip install pytorch-lightning

To install it in a local conda environment:

conda install -c conda-forge pytorch-lightning

I’d also recommend installing torchvision and OpenCV (imported as cv2 and installed with pip as opencv-python) to easily pre-process image data.

Dependencies

Let’s run all the imports we will require to get started.

import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import pandas as pd
import torchvision
import cv2
import torchvision.transforms as transforms
import os
from random import randint
# notebook-only command: enables the ipywidgets extension in Jupyter
!jupyter nbextension enable --py widgetsnbextension

If you cannot run this cell, one or more libraries haven’t been installed on the environment.

Preparing the Data

# declaring the path of the train and test folders
train_path = "../input/yoga-poses-dataset/DATASET/TRAIN"
test_path = "../input/yoga-poses-dataset/DATASET/TEST"

# each sub-folder of the train folder is one class; sorted() keeps the order stable between runs
classes_dir_data = sorted(os.listdir(train_path))
num_of_classes = len(classes_dir_data)
print("Total Number of Classes :", num_of_classes)

# classes_dict maps each class name to its numeric label.
# num_dict maps each numeric label back to its class name.
classes_dict = {}
num_dict = {}
for num, c in enumerate(classes_dir_data):
    classes_dict[c] = num
    num_dict[num] = c

classes_dict

output:-

The Image Dataset

#creating the dataset
class Image_Dataset(Dataset):
    def __init__(self, classes, image_base_dir, transform=None, target_transform=None):
        """
        classes: The classes in the dataset
        image_base_dir: The directory of the folders containing the images
        transform: The transformations for the images
        target_transform: The transformations for the target
        """
        self.img_labels = classes
        self.image_base_dir = image_base_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # one item per class, so the length is the number of classes
        return len(self.img_labels)

    def __getitem__(self, idx):
        # idx selects a class folder; a random image is drawn from that folder
        class_name = self.img_labels[idx]
        img_dir_list = os.listdir(os.path.join(self.image_base_dir, class_name))
        image_name = img_dir_list[randint(0, len(img_dir_list) - 1)]
        image_path = os.path.join(self.image_base_dir, class_name, image_name)
        image = cv2.imread(image_path)  # numpy array of shape (H, W, 3), BGR channel order
        label = class_name
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
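
Note that the dataset above draws one random image from a class folder on every access and reports its length as the number of classes, so one pass over it yields only as many samples as there are classes. An alternative design, sketched below as an assumption rather than the author's method, is to list every image once as a (path, label) pair and index into that list. `build_image_index` is a hypothetical helper, and the folder names are placeholders.

```python
import os
import tempfile


def build_image_index(root, class_to_idx):
    """List every file under root/<class>/ as a (path, label) pair."""
    samples = []
    for class_name, label in sorted(class_to_idx.items()):
        class_dir = os.path.join(root, class_name)
        for file_name in sorted(os.listdir(class_dir)):
            samples.append((os.path.join(class_dir, file_name), label))
    return samples


# Throwaway folder tree standing in for the train directory.
index_root = tempfile.mkdtemp()
fake_classes = {"pose_a": 0, "pose_b": 1}
for class_name in fake_classes:
    os.makedirs(os.path.join(index_root, class_name))
    for i in range(2):
        open(os.path.join(index_root, class_name, f"{i}.jpg"), "w").close()

samples = build_image_index(index_root, fake_classes)
print(len(samples))
```

A Dataset built on such a list would return `len(samples)` from `__len__` and read `samples[idx]` in `__getitem__`, so every image is visited exactly once per epoch.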

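Transformations

These are all the transformations that will be run on the dataset. basic_transformations holds the minimum transformations required to pass the data to the model, and training_transformations adds data augmentation on top of them. Using transforms.Compose, we can quickly build a pipeline.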

size = 50  # image height and width; the first linear layer of the model expects 12*50*50 inputs

basic_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.Grayscale(1),
    transforms.ToTensor()])
training_transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((size, size)),
    transforms.RandomRotation(degrees = 45),
    transforms.RandomHorizontalFlip(p = 0.005),  # p is the probability of flipping an image
    transforms.Grayscale(1),
    transforms.ToTensor()
])
def target_transformations(x):
    # convert a class name to its numeric label as a tensor
    return torch.tensor(classes_dict.get(x))

Data Module

This PyTorch Lightning data module will make passing data to the model easier for us.

class YogaDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage=None):
        # datasets are assigned in setup(), which Lightning calls on every process;
        # prepare_data() is meant for one-off work such as downloads, not for setting state
        self.train = Image_Dataset(classes_dir_data, train_path, training_transformations, target_transformations)
        # the validation and test sets both read from the test folder
        self.valid = Image_Dataset(classes_dir_data, test_path, basic_transformations, target_transformations)
        self.test = Image_Dataset(classes_dir_data, test_path, basic_transformations, target_transformations)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size = 64, shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.valid, batch_size = 64, shuffle = False)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size = 64, shuffle = False)
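
To make the role of the data loaders concrete, the sketch below batches some random tensors with the same batch size of 64. It is only an illustration: the 150 fake images and their labels are made-up data, and `TensorDataset` stands in for the image dataset defined earlier.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 150 fake grayscale 50x50 images with labels in the range 0-4 (random data).
images = torch.rand(150, 1, 50, 50)
labels = torch.randint(0, 5, (150,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

# Each batch has shape (batch_size, channels, height, width).
batch_shapes = [tuple(x.shape) for x, y in loader]
print(batch_shapes)
```

With 150 samples the loader yields two full batches of 64 and a final batch of 22. Lightning requests these batches from the data module and hands them to the model one at a time.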

Model

All the convolutions in the model retain the original input height and width; only the number of channels changes. The training_step and validation_step methods handle the training and validation of the model on each batch. By default, Lightning also saves a checkpoint of the model as training progresses. If you want to track a metric, just call self.log() and it will be saved by your preferred logger (be careful: logging on every step consumes memory). For more information about convolutions, I’d recommend checking out deeplizard’s free course on deep learning.

class YogaModel(LightningModule):
    def __init__(self):
        super().__init__()
        # The convolutions and the pooling layer are configured so that the image keeps
        # its height and width; only the number of channels changes.
        self.layer_1 = nn.Conv2d(in_channels = 1, out_channels = 3, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.layer_2 = nn.Conv2d(in_channels = 3, out_channels = 6, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.layer_3 = nn.Conv2d(in_channels = 6, out_channels = 12, kernel_size = (3,3), padding = (1,1), stride = (1,1))
        self.pool = nn.MaxPool2d(kernel_size = (3,3), padding = (1,1), stride = (1,1))
        # the input dimensions are (number of channels * height * width)
        self.layer_5 = nn.Linear(12*50*50, 1000)
        self.layer_6 = nn.Linear(1000, 100)
        self.layer_7 = nn.Linear(100, 50)
        self.layer_8 = nn.Linear(50, 10)
        self.layer_9 = nn.Linear(10, 5)

    def forward(self, x):
        """
        x is the input data
        """
        x = self.layer_1(x)
        x = self.pool(x)
        x = self.layer_2(x)
        x = self.pool(x)
        x = self.layer_3(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 12*50*50)
        x = self.layer_5(x)
        x = self.layer_6(x)
        x = self.layer_7(x)
        x = self.layer_8(x)
        x = self.layer_9(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-7)
        return optimizer

    # The Pytorch-Lightning module handles all the iterations of the epoch

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        self.log("loss", loss)
        return loss
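
The claim that the layers keep the image's height and width can be checked with the standard output-size formula for convolution and pooling layers, floor((size + 2 * padding - kernel) / stride) + 1. The sketch below is plain arithmetic, not part of the original notebook; `conv_output_size` is a hypothetical helper, and the 50 x 50 input is an assumption taken from the 12*50*50 input of the first linear layer.

```python
def conv_output_size(size, kernel, padding, stride):
    """Output length along one axis: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


side = 50  # assumed input height and width
for _ in range(3):  # three conv + pool pairs, all with kernel 3, padding 1, stride 1
    side = conv_output_size(side, kernel=3, padding=1, stride=1)  # convolution
    side = conv_output_size(side, kernel=3, padding=1, stride=1)  # max pooling

flattened = 12 * side * side  # 12 output channels after the third convolution
print(side, flattened)
```

The side length stays at 50 after every layer, so the flattened vector has 12 * 50 * 50 = 30000 values, which matches the input size of the first linear layer.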

Training

Now we will finally train the model. PyTorch Lightning makes using hardware easy: just declare the devices (CPUs or GPUs) you want to use for the model and Lightning will handle the rest.

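%%time
from pytorch_lightning import Trainer

model = YogaModel()
module = YogaDataModule()

# Don't go over 10000 - 100000 epochs or it will take 5 - 53+ hours to iterate.
# Hardware is chosen with Trainer arguments such as accelerator and devices in recent
# releases; the exact argument names depend on the installed version.
trainer = Trainer(max_epochs=1)
trainer.fit(model, module)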

output:-

Testing

The final cell checks the loss of the model on unseen data.

trainer.test(model, datamodule=module)

output:-

Notes

Improving the model

  • If the loss on the train set is very high, the model is under-fitting. To decrease the loss, increase the number of max epochs or the learning rate. You can also add more non-linear layers to the model.
  • If the loss on the test set is high while the loss on the train set is low, the model has over-fitted to the train set. Decreasing the number of epochs, increasing the learning rate, or adding dropout layers should do the trick.
  • This notebook hasn’t used any callback methods. To see the callbacks available in Lightning, check out the official website.
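
One callback relevant to the over-fitting note is early stopping, which Lightning provides as `EarlyStopping` in `pytorch_lightning.callbacks`. The sketch below is not that implementation. It is a plain-Python illustration of the patience idea behind it, run on a made-up list of validation losses; `early_stop_epoch` is a hypothetical helper.

```python
def early_stop_epoch(val_losses, patience=2):
    """Return the index of the epoch at which training would stop,
    or None if the loss never fails to improve `patience` times in a row."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: improving at first, then getting worse.
fake_losses = [1.60, 1.20, 0.90, 0.95, 1.00, 1.10]
stop_at = early_stop_epoch(fake_losses, patience=2)
print(stop_at)
```

Here the loss stops improving after epoch 2, so with a patience of 2 training would halt at epoch 4 instead of running for all the configured epochs.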
The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.
