Long Short Term Memory (LSTM) Networks using PyTorch
Last Updated : 9 Oct, 2025
Long Short-Term Memory (LSTM) networks are a special type of Recurrent Neural Network (RNN) designed to address the vanishing gradient problem, which makes it difficult for traditional RNNs to learn long-term dependencies in sequential data. They do this with a memory cell regulated by three gates:
Input Gate: decides what new information should be stored.
Forget Gate: decides what information should be discarded.
Output Gate: decides what information to output at each step.
This structure allows LSTMs to remember useful information for long periods while ignoring irrelevant details. In this article, we will learn how to implement an LSTM in PyTorch for sequence prediction on synthetic sine wave data.
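To make the role of the three gates concrete, here is a minimal sketch of a single LSTM time step written out by hand and compared against PyTorch's nn.LSTMCell. The sizes, the random inputs and the variable names are assumptions chosen only for illustration; the gate ordering (input, forget, candidate, output) follows how PyTorch stacks its weight blocks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 3, 4            # illustrative sizes (assumed)
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(2, input_size)            # batch of 2 inputs at one time step
h_prev = torch.randn(2, hidden_size)      # previous hidden state
c_prev = torch.randn(2, hidden_size)      # previous cell (memory) state

# PyTorch stacks the four weight blocks in the order:
# input gate, forget gate, candidate values, output gate.
gates = (x @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh)
i, f, g, o = gates.chunk(4, dim=1)

i = torch.sigmoid(i)   # input gate: how much new information to store
f = torch.sigmoid(f)   # forget gate: how much old memory to keep
g = torch.tanh(g)      # candidate values that may be written to the cell
o = torch.sigmoid(o)   # output gate: how much of the cell to expose

c_next = f * c_prev + i * g            # updated cell state
h_next = o * torch.tanh(c_next)        # updated hidden state

# Compare with the built-in cell to check the hand-written step.
h_ref, c_ref = cell(x, (h_prev, c_prev))
manual_matches = (torch.allclose(h_next, h_ref, atol=1e-5)
                  and torch.allclose(c_next, c_ref, atol=1e-5))
print(manual_matches)
```

The forget gate scales the old cell state and the input gate scales the new candidate values, which is how the cell can carry information across many steps.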
Long Short-Term Memory (LSTM) Networks using PyTorch
LSTMs are widely used for sequence modeling tasks because of their ability to capture long-term dependencies. PyTorch provides a clean and flexible API to build and train LSTM models. In PyTorch, the nn.LSTM module handles the recurrence logic, while the rest of the architecture (such as fully connected layers, dropout, etc.) can be customized as needed.
Key Components
1. Input Size: Number of features in the input sequence at each time step.
2. Hidden Size: Number of features in the hidden state.
3. Number of Layers: Number of stacked LSTM layers; stacking more layers deepens the model.
4. Batch First: If set to True, input/output tensors are provided as (batch, seq_len, features) instead of (seq_len, batch, features).
5. Outputs:
Output Sequence: Hidden states of the last LSTM layer at each time step.
Hidden State: Final hidden state for all layers.
Cell State: Final memory cell state for all layers.
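The components listed above map directly onto the arguments and return values of nn.LSTM. The following sketch uses arbitrary, assumed sizes to show the shapes involved when batch_first=True.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed, not taken from the article)
batch_size, seq_len = 8, 10
input_size, hidden_size, num_layers = 1, 16, 2

lstm = nn.LSTM(input_size=input_size,
               hidden_size=hidden_size,
               num_layers=num_layers,
               batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)   # (batch, seq_len, features)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, hidden_size): last layer, every time step
print(h_n.shape)     # (num_layers, batch, hidden_size): final hidden state per layer
print(c_n.shape)     # (num_layers, batch, hidden_size): final cell state per layer
```

Note that batch_first only changes the layout of the input and the output sequence; the hidden and cell states keep the layer dimension first. For a unidirectional LSTM, the last time step of the output sequence equals the final hidden state of the top layer.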
Step 1: Import Libraries and Prepare the Dataset
We first import the necessary libraries (torch, numpy and matplotlib) and create a sine wave dataset. The data is split into input sequences of length 10, and the model learns to predict the value that follows each sequence.
np.linspace(): generates evenly spaced points.
np.sin(): creates sine values.
create_sequences(): prepares input-output pairs.
torch.tensor(): converts NumPy arrays into PyTorch tensors.
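The data preparation described above could look roughly like the sketch below. The range and number of points passed to np.linspace(), the signature of create_sequences() and the tensor variable names are assumptions; only the sequence length of 10 comes from the text. Plotting with matplotlib is left out to keep the example short.

```python
import numpy as np
import torch

def create_sequences(data, seq_length):
    """Slide a window over data; each window predicts the value right after it."""
    xs, ys = [], []
    for i in range(len(data) - seq_length):
        xs.append(data[i:i + seq_length])   # input window
        ys.append(data[i + seq_length])     # next value (target)
    return np.array(xs), np.array(ys)

# Assumed range and resolution for the sine wave
t = np.linspace(0, 100, 1000)
data = np.sin(t)

seq_length = 10
X, y = create_sequences(data, seq_length)

# Add a trailing feature dimension so inputs are (batch, seq_len, 1)
X_tensor = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)
y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)

print(X_tensor.shape, y_tensor.shape)
```

The trailing dimension of size 1 is the input size: each time step carries a single feature, the sine value.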
Step 2: Define the LSTM Model
We define an LSTM model using PyTorch's nn.Module.
nn.LSTM: processes sequential data.
nn.Linear: maps hidden state outputs to predictions.
forward(): passes the data through the LSTM and then the fully connected layer.
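A model built from these pieces might be sketched as follows. The class name LSTMModel, the default hidden size of 50 and the choice to predict from the last time step are assumptions made for illustration, not necessarily the exact model used in this article.

```python
import torch
import torch.nn as nn

class LSTMModel(nn.Module):  # hypothetical name
    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1):
        super().__init__()
        # Recurrent part: consumes (batch, seq_len, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Maps the hidden state to the predicted value
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)            # out: (batch, seq_len, hidden_size)
        last_step = out[:, -1, :]        # hidden state at the final time step
        return self.fc(last_step)        # (batch, output_size)

model = LSTMModel()
dummy = torch.randn(4, 10, 1)            # 4 sequences of length 10, 1 feature
pred = model(dummy)
print(pred.shape)
```

Using only the last time step suits next-value prediction, where one output is needed per input sequence.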
Step 3: Initialize Model, Loss Function, and Optimizer