URL: https://thenewstack.io/tutorial-train-a-deep-learning-model-in-pytorch-and-export-it-to-onnx/

Tutorial: Train a Deep Learning Model in PyTorch and Export It to ONNX - The New Stack

In this tutorial, we will train a Convolutional Neural Network in PyTorch and convert it into an ONNX model. Once we have the model in ONNX format, we can import it into other frameworks such as TensorFlow, either for inference or to reuse the model through transfer learning.
Jul 17th, 2020 8:51am by Janakiram MSV
Feature image: “Taking in the Wheat Sheaves” via New Old Stock.
This post is the third in a series of introductory tutorials on the Open Neural Network Exchange (ONNX), an initiative from AWS, Microsoft, and Facebook to define a standard for interoperability across machine learning platforms. See: Part 1, Part 2.

Setting up the Environment

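The only prerequisite for this tutorial is Python 3. Because TensorFlow 1.15 only provides packages for Python 3.7 and earlier, use Python 3.5, 3.6, or 3.7, and make sure it is installed on your machine.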

Create a Python virtual environment that will be used for this and the next tutorial.

```shell
python3 -m virtualenv pyt2tf
source pyt2tf/bin/activate
```

Create a file named requirements.txt with the following content, which lists the modules needed for the tutorial.

```text
torch
torchvision
opencv-python
tensorflow==1.15
onnx
onnxruntime
onnx_tf
```

Note that we are using TensorFlow 1.x for this tutorial. You may see errors if you install any version of TensorFlow above 1.15.

Install the modules from the above file with pip.

```shell
pip install -r requirements.txt
```

Finally, create a directory to save the model.

```shell
mkdir output
```
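Before moving on, you may want to confirm that the environment is usable. The following is a minimal sketch, not part of the original tutorial, that reports the version of each module by its import name (for example, opencv-python is imported as cv2) and checks the interpreter version. The helper names module_version and python_supports_tf115 are hypothetical, and the 3.5 to 3.7 range is an assumption based on the Python 3 versions TensorFlow 1.15 shipped packages for.

```python
import importlib
import sys


def module_version(module_name):
    """Return a module's __version__, "unknown" if it has none, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except Exception:  # ImportError, or errors raised by a broken/incompatible install
        return None
    return getattr(module, "__version__", "unknown")


def python_supports_tf115(version_info=None):
    """True for Python 3.5-3.7, the Python 3 versions assumed to have TensorFlow 1.15 packages."""
    if version_info is None:
        version_info = sys.version_info
    return (3, 5) <= tuple(version_info[:2]) <= (3, 7)


if __name__ == "__main__":
    print("Python {} ({})".format(
        sys.version.split()[0],
        "OK for TensorFlow 1.15" if python_supports_tf115() else "not supported by TensorFlow 1.15"))
    # Import names, not pip names: opencv-python is imported as cv2
    for name in ["torch", "torchvision", "cv2", "tensorflow", "onnx", "onnxruntime", "onnx_tf"]:
        version = module_version(name)
        print("{:12} {}".format(name, version if version is not None else "not importable"))
```

The script uses str.format rather than f-strings so that it also runs on Python 3.5.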

Train a CNN with the MNIST Dataset

Let’s start by importing the modules the program needs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
```

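We will then define the neural network: two convolutional layers, each followed by ReLU and max pooling, and two fully connected layers that output log-probabilities for the 10 digit classes.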

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```
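The 4*4*50 input size of fc1 comes from tracking how each layer shrinks the 28x28 image. With no padding, a convolution or pooling window of size k and stride s maps a side of length n to floor((n - k) / s) + 1. The plain-Python sketch below (conv_out and flattened_features are illustrative names, not part of PyTorch) walks through that arithmetic for this network:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Side length after a Conv2d or max-pool window (no dilation, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size=28):
    """Number of features fed to fc1 for a square input image."""
    size = conv_out(image_size, kernel=5)       # conv1, 5x5 kernel: 28 -> 24
    size = conv_out(size, kernel=2, stride=2)   # max_pool2d(x, 2, 2): 24 -> 12
    size = conv_out(size, kernel=5)             # conv2, 5x5 kernel: 12 -> 8
    size = conv_out(size, kernel=2, stride=2)   # max_pool2d(x, 2, 2): 8 -> 4
    channels = 50                               # out_channels of conv2
    return size * size * channels


print(flattened_features())  # 800, i.e. 4*4*50
```

If you change the input size or a kernel size, the same calculation tells you what the first argument of nn.Linear for fc1 (and the x.view call) must become.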

Next, create a function that trains the PyTorch model for one epoch.

```python
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
```
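The network returns log_softmax outputs, and train applies F.nll_loss to them. Together the two compute the cross-entropy loss, which is why this pairing is equivalent to calling F.cross_entropy on the raw scores. The pure-Python sketch below (no PyTorch; the function names are illustrative) checks that equivalence on a tiny batch and shows why an untrained 10-class model starts with a loss near log(10), about 2.30:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of raw scores."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def nll_loss(log_prob_rows, targets):
    """Mean negative log-likelihood, like F.nll_loss with its default reduction='mean'."""
    losses = [-row[t] for row, t in zip(log_prob_rows, targets)]
    return sum(losses) / len(losses)


def cross_entropy(logit_rows, targets):
    """Cross-entropy computed directly as -log(softmax(z)[target]), averaged over the batch."""
    total = 0.0
    for logits, t in zip(logit_rows, targets):
        exps = [math.exp(z) for z in logits]
        total += -math.log(exps[t] / sum(exps))
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(math.isclose(nll_loss([log_softmax(z) for z in logits], targets),
                   cross_entropy(logits, targets)))  # True

# Equal scores for all 10 classes give a loss of log(10)
print(nll_loss([log_softmax([0.0] * 10)], [3]))
```

This is a useful sanity check when reading the training log: the first reported loss should be in the neighborhood of 2.3 and fall from there.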

The function below evaluates the model on the test set and reports the average loss and accuracy:

```python
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
```
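In test, F.nll_loss is called with reduction='sum' and the total is divided by len(test_loader.dataset) at the end. That gives the exact per-image average even when the last batch is smaller than the others, whereas averaging per-batch means would over-weight a short final batch. (With 10,000 MNIST test images and a batch size of 1,000, every batch happens to be the same size here.) The plain-Python sketch below mirrors that loop, including the argmax-based accuracy count; evaluate is a hypothetical helper that takes lists of log-probability rows instead of tensors:

```python
def evaluate(batches):
    """batches: iterable of (log_prob_rows, targets) pairs, mimicking test_loader output."""
    total_loss, correct, count = 0.0, 0, 0
    for rows, targets in batches:
        for row, target in zip(rows, targets):
            total_loss += -row[target]                          # reduction='sum'
            pred = max(range(len(row)), key=lambda i: row[i])   # argmax over classes
            correct += int(pred == target)
            count += 1
    return total_loss / count, correct / count                  # divide by dataset size


# Two batches of unequal size (2 images, then 1), two classes each.
batches = [
    ([[-0.1, -2.3], [-1.2, -0.4]], [0, 0]),
    ([[-3.0, -0.05]], [1]),
]
avg_loss, accuracy = evaluate(batches)
print(round(avg_loss, 4), round(accuracy, 4))  # 0.45 0.6667
```

Averaging the two batch means instead would report (0.65 + 0.05) / 2 = 0.35, understating the true per-image loss of 0.45.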

With the network architecture and the train and test functions in place, let’s create the main function, which instantiates the neural network and trains it on the MNIST dataset.

```python
def main():

    device = "cpu"

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for epoch in range(0, 10):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

    torch.save(model.state_dict(), "output/model.pt")


if __name__ == '__main__':
    main()
```
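The optimizer in main is SGD with lr=0.01 and momentum=0.5. As described in the PyTorch documentation for optim.SGD (with the defaults dampening=0, no Nesterov momentum, and no weight decay), each step updates a velocity buffer v = momentum * v + grad and then moves the parameter by p = p - lr * v. The sketch below applies that rule in plain Python to a single number while minimizing f(x) = x**2; sgd_momentum is an illustrative name, not a PyTorch function:

```python
def sgd_momentum(grad_fn, x, lr=0.01, momentum=0.5, steps=3):
    """Apply the SGD-with-momentum update rule to a scalar; returns every iterate."""
    buf = None
    history = [x]
    for _ in range(steps):
        g = grad_fn(x)
        # First step: the buffer is just the gradient; afterwards it blends in past velocity
        buf = g if buf is None else momentum * buf + g
        x = x - lr * buf
        history.append(x)
    return history


# Gradient of f(x) = x**2 is 2x; start from x = 1.0
print(sgd_momentum(lambda x: 2 * x, 1.0, steps=2))
```

Starting from 1.0, the first two steps move x to 0.98 and then to about 0.9504. The second step (0.0296) is larger than the first (0.02) even though the gradient shrank, because half of the previous velocity carries over.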

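Within the main function, we download the MNIST dataset, preprocess it, train the model for 10 epochs while evaluating it on the test set after each one, and finally save the trained weights to output/model.pt.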
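The preprocessing has two steps: transforms.ToTensor() converts each 8-bit pixel in the range 0 to 255 into a float between 0.0 and 1.0, and transforms.Normalize((0.1307,), (0.3081,)) then subtracts the mean and divides by the standard deviation (0.1307 and 0.3081 are the commonly used mean and standard deviation of the MNIST training pixels). The plain-Python sketch below applies the same arithmetic to single pixel values; preprocess_pixel is a hypothetical name:

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # the values passed to transforms.Normalize in main


def preprocess_pixel(value):
    """Map a raw 8-bit grayscale pixel the way ToTensor() followed by Normalize() does."""
    scaled = value / 255.0                    # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize: (x - mean) / std


for raw in (0, 128, 255):
    print(raw, round(preprocess_pixel(raw), 4))
```

A black background pixel (0) becomes about -0.4242 and a fully white pixel (255) about 2.8215, so across the training set the network sees inputs with roughly zero mean and unit variance.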

If you are training the model on a beefy box with a powerful GPU, you can set the device variable to "cuda" and tweak the number of epochs to get better accuracy. But for the MNIST dataset, you will hit ~98% accuracy with just 10 epochs running on the CPU.
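If you would rather switch devices or change the epoch count without editing the source, one option is to read both from the command line. This is a hypothetical addition, not part of the original code: the --device and --epochs flag names are assumptions, and main would use args.device in place of the hard-coded "cpu" and range(args.epochs) in the training loop. It relies only on the standard library argparse module:

```python
import argparse


def parse_args(argv=None):
    """Hypothetical command-line options replacing the hard-coded values in main()."""
    parser = argparse.ArgumentParser(description="Train the MNIST CNN")
    parser.add_argument("--device", default="cpu",
                        help='PyTorch device string, e.g. "cpu" or "cuda"')
    parser.add_argument("--epochs", type=int, default=10,
                        help="number of passes over the training set")
    return parser.parse_args(argv)  # argv=None means "read sys.argv"


args = parse_args(["--device", "cuda", "--epochs", "20"])
print(args.device, args.epochs)  # cuda 20
```

PyTorch also provides torch.cuda.is_available(), which you could use to choose "cuda" automatically when a GPU is present.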
Below is the complete code to train the model in PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():

    device = "cpu"

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for epoch in range(0, 10):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

    torch.save(model.state_dict(), "output/model.pt")


if __name__ == '__main__':
    main()
```

Once training is done, you will find the file model.pt in the output directory. It holds the trained weights (the model's state_dict), which is the artifact we need to convert the model into ONNX format.
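A state_dict maps each parameter name to a tensor. For this Net, the entries follow directly from the layer definitions: Conv2d weights have shape (out_channels, in_channels, kernel_height, kernel_width) and Linear weights have shape (out_features, in_features). The plain-Python sketch below lists the expected names and shapes (EXPECTED_SHAPES and num_elements are illustrative names) and totals the parameters:

```python
# Expected state_dict entries for Net, derived from its layer definitions
EXPECTED_SHAPES = {
    "conv1.weight": (20, 1, 5, 5), "conv1.bias": (20,),
    "conv2.weight": (50, 20, 5, 5), "conv2.bias": (50,),
    "fc1.weight": (500, 4 * 4 * 50), "fc1.bias": (500,),  # nn.Linear stores (out, in)
    "fc2.weight": (10, 500), "fc2.bias": (10,),
}


def num_elements(shape):
    """Number of values in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


total = sum(num_elements(shape) for shape in EXPECTED_SHAPES.values())
print(total)  # 431080
```

That is 431,080 values, or about 1.7 MB of raw float32 data, so expect model.pt to be on that order once serialization overhead is included. To compare against the real file after training, you could print {k: tuple(v.shape) for k, v in torch.load("output/model.pt").items()}.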

Exporting the PyTorch Model to ONNX Format

PyTorch supports ONNX natively through its torch.onnx module, which means we can export the model without installing an additional package.

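Let’s load the trained model from the previous step, create a dummy input that matches the shape of the network's input tensor (a batch containing one single-channel 28x28 image), and export the model to ONNX. The exporter runs the dummy input through the model to record the operations it performs; the values in the tensor do not matter, only its shape.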

The script redefines the Net class because model.pt contains only the learned weights, not the architecture: the network has to be rebuilt before the weights can be loaded into it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# Rebuild the architecture, then load the trained weights into it
trained_model = Net()
trained_model.load_state_dict(torch.load('output/model.pt'))
trained_model.eval()  # switch to inference mode before exporting

# Dummy input shaped (batch, channels, height, width): one 28x28 grayscale image
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(trained_model, dummy_input, "output/model.onnx")
```

Running the above code creates output/model.onnx, which contains the ONNX version of the deep learning model originally trained in PyTorch.

You can open this file in Netron, an open source viewer for neural network models, to explore the layers and the architecture of the network.

In the next part of this tutorial, we will import the ONNX model into TensorFlow and use it for inference. Stay tuned.

Janakiram MSV’s Webinar series, “Machine Intelligence and Modern Infrastructure (MI2)” offers informative and insightful sessions covering cutting-edge technologies. Sign up for the upcoming MI2 webinar at http://mi2.live.

Janakiram MSV (Jani) is a practicing architect, research analyst, and advisor to Silicon Valley startups. He focuses on the convergence of modern infrastructure powered by cloud-native technology and machine intelligence driven by generative AI. Before becoming an entrepreneur, he spent...