
URL: https://www.geeksforgeeks.org/deep-learning/training-of-recurrent-neural-networks-rnn-in-tensorflow/




Training of Recurrent Neural Networks (RNN) in TensorFlow

Last Updated : 15 May, 2026

Recurrent Neural Networks (RNNs) are neural networks designed to process sequential data by maintaining hidden states that store information from previous steps. In this implementation, TensorFlow is used to build and train an RNN that classifies the sentiment of text data.

Implementation

1. Importing Libraries

We will import Pandas, NumPy, Matplotlib, Seaborn, Plotly Express, TensorFlow, Keras, NLTK and Scikit-learn for this implementation.

2. Loading the Dataset

The dataset is loaded using pd.read_csv() and cleaned by removing rows with null values in the Class Name column.

  • Loads dataset using Pandas
  • Displays first 7 rows using data.head(7)
  • Removes null values from Class Name column
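A minimal sketch of this step, assuming the data is a CSV file with a Class Name column. The helper name load_reviews and the tiny inline CSV are hypothetical stand-ins for the real dataset file, whose path is not given here.

```python
import io

import pandas as pd


def load_reviews(source):
    """Read the CSV and drop rows whose 'Class Name' value is missing."""
    data = pd.read_csv(source)
    print(data.head(7))  # preview the first seven rows
    # subset= restricts the null check to the 'Class Name' column only
    return data.dropna(subset=["Class Name"]).reset_index(drop=True)


# Made-up rows standing in for the real file; the second row has no Class Name
sample = io.StringIO(
    "Class Name,Rating,Review Text\n"
    "Dresses,5,Love this dress\n"
    ",4,Nice fabric\n"
    "Knits,2,Runs small\n"
)
data = load_reviews(sample)
print(data.shape)  # (2, 3)
```

With the real dataset, the same function is called with the path of the CSV file instead of the in-memory buffer.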

Output:

First rows of the dataset

3. Performing Exploratory Data Analysis

Exploratory data analysis (EDA) uses visualizations to reveal the distributions and patterns in the dataset before the model is built.

Count Plot of Class Name Distribution

sns.countplot() is used to visualize the count of each category in the Class Name column. The x-axis labels are rotated using plt.xticks(rotation=90) for better readability.
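A sketch of the count plot, assuming a DataFrame named data with a Class Name column. The rows below are made up, and the figure size and output file name are arbitrary choices.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Made-up rows standing in for the cleaned dataset
data = pd.DataFrame(
    {"Class Name": ["Dresses", "Knits", "Dresses", "Blouses", "Knits", "Dresses"]}
)

plt.figure(figsize=(10, 4))
ax = sns.countplot(data=data, x="Class Name")  # one bar per category, height = row count
plt.xticks(rotation=90)  # vertical labels so long class names do not overlap
plt.tight_layout()
plt.savefig("class_name_counts.png")  # in a notebook, plt.show() would display it instead
```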

Output:

Count plot of Class Name distribution

Count Plot of Rating and Recommendation Distribution

A 12×5 figure with two subplots is created using plt.subplots() to show count plots of the Rating and Recommended IND columns.
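A sketch of the two side-by-side count plots, assuming columns named Rating and Recommended IND; the sample values are invented for illustration.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Invented sample standing in for the real data
data = pd.DataFrame({
    "Rating": [5, 4, 5, 3, 1, 5, 2],
    "Recommended IND": [1, 1, 1, 0, 0, 1, 0],
})

# One row, two columns of axes on a 12 x 5 inch figure
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
sns.countplot(data=data, x="Rating", ax=axes[0])           # bars for each rating value
axes[0].set_title("Rating")
sns.countplot(data=data, x="Recommended IND", ax=axes[1])  # bars for 0 and 1
axes[1].set_title("Recommended IND")
fig.tight_layout()
```

Passing ax= to each countplot call draws it on a specific subplot instead of the current axes.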

Output:

Count plots for the Rating and Recommended IND columns

Histogram of Age Distribution

A histogram is created using px.histogram() from Plotly Express to visualize the frequency distribution of age for recommended and non-recommended individuals. A marginal box plot is added to show the spread and outliers of each group.

Output:

Histogram of Age Distribution

Interpretation of Age Distribution Plot

The histogram shows age distribution for recommended and non-recommended individuals, while the box plots display the spread and outliers for each group.

  • Green bars represent recommended individuals
  • Red bars represent non-recommended individuals
  • Box plots show spread and outliers of age values
  • Helps compare age distribution between groups
  • The same plot can be grouped by Rating to compare age distributions across ratings

Output:

Interpretation of Age Distribution Plot

4. Preparing the Data for the Model

Since the dataset is NLP-based, the text columns are used as input features and the Rating column provides the sentiment labels. The ratings are unevenly distributed, so instead of predicting each rating level the task is reduced to two classes: ratings above 3 are converted to 1 (positive) and ratings below 3 are converted to 0 (negative). Ratings of exactly 3 fall into neither class.

  • Uses text columns as input features
  • Derives sentiment labels from the Rating column
  • Reduces the unevenly distributed ratings to two classes
  • Converts ratings >3 to the positive class (1)
  • Converts ratings <3 to the negative class (0)
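The labelling rule can be sketched in Pandas as follows. Because the rule leaves a rating of exactly 3 unassigned, this sketch assumes those rows are dropped; the Sentiment column name and the sample rows are hypothetical.

```python
import pandas as pd

# Hypothetical rows standing in for the real data
data = pd.DataFrame({
    "Review Text": ["Love it", "Okay", "Too small", "Perfect fit", "Poor quality"],
    "Rating": [5, 3, 2, 4, 1],
})

# Assumption: rating 3 belongs to neither class, so those rows are removed
data = data[data["Rating"] != 3].copy()

# Rating > 3 -> 1 (positive); the remaining ratings (< 3) -> 0 (negative)
data["Sentiment"] = (data["Rating"] > 3).astype(int)
print(data["Sentiment"].tolist())  # [1, 0, 1, 0]
```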

5. Text Preprocessing

Text preprocessing is performed to clean and standardize the text data before training the model. The text is converted to lowercase, lemmatized and cleaned by removing stopwords and punctuation.

  • Converts text to lowercase for consistency
  • Applies lemmatization to normalize words
  • Removes stopwords and punctuation
  • Reduces noise and improves text quality for training

6. Tokenization

Tokenization converts text into sequences of integer word indices that the neural network can process. Keras provides a Tokenizer API that builds a word index from the text data.

  • Converts each text into a sequence of integers
  • Uses the Keras Tokenizer for preprocessing
  • num_words limits the vocabulary to the most frequent words
  • oov_token sets a placeholder that replaces out-of-vocabulary words
  • fit_on_texts() is applied only to the training data, so the vocabulary contains no information from the test set
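To show what the Tokenizer does, here is a plain-Python sketch that mimics its num_words and oov_token behaviour; the function names and the "<OOV>" string are hypothetical, and it assumes the text has already been cleaned, so it only lowercases and splits on whitespace. In Keras the equivalent calls are fit_on_texts() and texts_to_sequences(). The TensorFlow documentation marks tf.keras.preprocessing.text.Tokenizer as deprecated and recommends the tf.keras.layers.TextVectorization layer for new code.

```python
from collections import Counter

OOV_TOKEN = "<OOV>"  # hypothetical placeholder string


def fit_word_index(train_texts, oov_token=OOV_TOKEN):
    """Build word -> index from training texts only; index 0 is left free for padding."""
    counts = Counter()
    for text in train_texts:
        counts.update(text.lower().split())
    # most_common() sorts by frequency; ties keep first-seen order
    vocab = [oov_token] + [word for word, _ in counts.most_common()]
    return {word: idx for idx, word in enumerate(vocab, start=1)}


def to_sequences(texts, word_index, num_words, oov_token=OOV_TOKEN):
    """Map each text to indices; unseen or too-rare words become the OOV index."""
    oov_index = word_index[oov_token]
    sequences = []
    for text in texts:
        seq = []
        for word in text.lower().split():
            idx = word_index.get(word, oov_index)
            # Like Keras, only indices below num_words are kept
            seq.append(idx if idx < num_words else oov_index)
        sequences.append(seq)
    return sequences


train_texts = ["good dress", "good fit", "bad fit good"]
test_texts = ["good dress", "great fit"]

word_index = fit_word_index(train_texts)  # fitted on training data only
print(word_index)  # {'<OOV>': 1, 'good': 2, 'fit': 3, 'dress': 4, 'bad': 5}
print(to_sequences(test_texts, word_index, num_words=4))  # [[2, 1], [1, 3]]
```

With num_words=4 only indices 1 to 3 survive, so "dress" (index 4) and the unseen word "great" both map to the OOV index 1.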

7. Padding the Text Data

Padding is used to make all text sequences the same length before feeding them into the neural network. Extra zeros are added to shorter sequences, while longer sequences can be truncated if needed.

  • Makes all text sequences equal in length
  • Adds zeros to shorter sequences
  • Longer sequences can be truncated
  • Padding and tokenization are general NLP preprocessing techniques, not specific to RNNs
  • Equal-length sequences can be stacked into a single tensor, so the model can be trained in batches
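Keras offers tf.keras.utils.pad_sequences for this step; by default it pads and truncates at the start of each sequence (padding="pre", truncating="pre"). The NumPy function below is a hypothetical re-implementation written only to make that behaviour visible, with the padding and truncation side exposed as parameters.

```python
import numpy as np


def pad(sequences, maxlen, padding="post", truncating="post", value=0):
    """Pad/truncate a list of int lists into an int32 array of shape (n, maxlen)."""
    out = np.full((len(sequences), maxlen), value, dtype="int32")
    for i, seq in enumerate(sequences):
        # Cut long sequences at the end ("post") or keep only their tail ("pre")
        seq = seq[:maxlen] if truncating == "post" else seq[-maxlen:]
        if len(seq) == 0:
            continue  # an empty sequence stays all padding
        if padding == "post":
            out[i, : len(seq)] = seq   # zeros after the tokens
        else:
            out[i, -len(seq):] = seq   # zeros before the tokens
    return out


seqs = [[1, 2, 3], [4], [5, 6, 7, 8, 9]]
print(pad(seqs, maxlen=4).tolist())
# [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 7, 8]]
print(pad(seqs, maxlen=4, padding="pre", truncating="pre").tolist())
# [[0, 1, 2, 3], [0, 0, 0, 4], [6, 7, 8, 9]]
```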

8. Building a Recurrent Neural Network (RNN) in TensorFlow

After preprocessing the data, a Simple Recurrent Neural Network (SimpleRNN) is built for training. Before entering the RNN layer, the text data is passed through an Embedding layer to generate fixed-size word vectors.

  • Builds a SimpleRNN model using TensorFlow
  • Uses an Embedding layer before the RNN layer
  • Embedding converts words into dense vector representations
  • Dense, learned word vectors give the RNN a compact input in which related words can have similar representations
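A sketch of such a model with tf.keras, assuming a vocabulary of 10,000 indices, sequences padded to length 100, a 32-dimensional embedding and 64 recurrent units; all four sizes are assumptions, not values taken from the original code. The single sigmoid unit outputs the probability of the positive class (1).

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE = 10_000  # assumed: should match the tokenizer's num_words
MAX_LEN = 100        # assumed: should match the padding length
EMBED_DIM = 32       # assumed embedding size

model = tf.keras.Sequential([
    tf.keras.Input(shape=(MAX_LEN,), dtype="int32"),
    # Looks up a dense EMBED_DIM vector for each word index in [0, VOCAB_SIZE)
    tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBED_DIM),
    # Reads the sequence step by step and returns its final hidden state
    tf.keras.layers.SimpleRNN(64),
    # One sigmoid unit: probability that the text is positive
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.summary()

# Shape check on a dummy batch of two all-padding sequences
probs = model(np.zeros((2, MAX_LEN), dtype="int32"))
print(probs.shape)
```

Every index the tokenizer can emit must be below input_dim; with a Keras Tokenizer capped by num_words, setting input_dim to num_words satisfies this.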

Output:

👁 Summary of the architecture of the model
Summary of the architecture of the model

9. Training the Model

After building the model, it is compiled using an optimizer, loss function and evaluation metric. The model is then trained on the preprocessed training data for multiple epochs.

  • Compiles model using optimizer, loss function and evaluation metric
  • Trains the model on the padded training sequences (train_pad)
  • Uses y_train as target labels
  • Trains for 5 epochs, reporting loss and accuracy after each epoch
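A self-contained sketch of the compile and fit calls. Adam and binary cross-entropy are assumed choices that suit a single sigmoid output with 0/1 labels, and the layer sizes are arbitrary. The random arrays only stand in for train_pad and y_train so the snippet runs on its own; the accuracy they produce is meaningless.

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE, MAX_LEN = 1_000, 20  # assumed sizes for this toy run

model = tf.keras.Sequential([
    tf.keras.Input(shape=(MAX_LEN,), dtype="int32"),
    tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=16),
    tf.keras.layers.SimpleRNN(32),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Binary cross-entropy pairs with the sigmoid output and 0/1 labels
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Random stand-ins for the padded sequences and their sentiment labels
rng = np.random.default_rng(0)
train_pad = rng.integers(1, VOCAB_SIZE, size=(64, MAX_LEN)).astype("int32")
y_train = rng.integers(0, 2, size=(64,)).astype("float32")

history = model.fit(train_pad, y_train, epochs=5, batch_size=16, verbose=0)
print(history.history["accuracy"])  # one accuracy value per epoch
```

history.history holds one list per metric, so the per-epoch loss and accuracy can be plotted or compared after training.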

Output:

Training the Model

