
URL: https://towardsdatascience.com/forgetting-in-deep-learning-4672e8843a7f/



Forgetting in Deep Learning

Team members: Qiang Fei, Yingsi Jian, Mingyue Wei, Shuyuan Xiao


A study of techniques that are related to catastrophic forgetting in deep neural networks.

In partnership with Google


You can also check out the contents through our project poster and video.

Disclaimer: The views expressed in this blog are those of the authors and are not endorsed by Harvard or Google.


Problem Statement

Neural network models suffer from catastrophic forgetting: a model can drastically lose its generalization ability on a task after being trained on a new task. This typically happens because training on the new task overwrites weights learned for past tasks (see Figure 1), which degrades the model's performance on those tasks. Unless this problem is solved, a single neural network cannot adapt to a continual learning scenario, because it forgets existing knowledge whenever it learns something new.

Figure 1: Demonstration of catastrophic forgetting¹. Source: Attention-Based Selective Plasticity (Kolouri et al., 2019)

For realistic applications of deep learning where continual learning is crucial, catastrophic forgetting needs to be avoided. However, there has been only limited study of catastrophic forgetting and its underlying causes. In this project, we explore how commonly used deep learning methods (e.g. batch norm, dropout, data augmentation, weight decay) mitigate or exacerbate forgetting. We then select one or more of these methods and try to understand the causes of their effects.

Literature Review

To measure forgetting, the main approach is to revisit a task after training on later tasks and compare the accuracy before and after². For other task settings, Kemker et al. proposed a method specifically for incremental class learning³ (in each new session, data from a single class is learned), and Arora et al. suggested using scaled accuracy when comparing different models⁴.

Several ways to reduce forgetting have been studied or proposed, including adding dropout layers⁵, using max pooling, decreasing the number of layers, and decreasing the learning rate.

Data

The dataset we use is CIFAR-10. It consists of 60,000 color images of size 32×32, evenly divided into 10 classes. For each class, 5,000 images are in the training set and the remaining 1,000 are in the test set. The 10 classes are:

airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

Experiment Setup

We went with the approach of comparing before and after accuracy to measure forgetting effects. Our experiment workflow is shown in Figure 2.

👁 Figure 2: Experiment workflow. Image by Author
Figure 2: Experiment workflow. Image by Author

The whole dataset is first split evenly into task 1 and task 2, each consisting of 5 classes of images. We train the model consecutively on task 1 and then on task 2. After that, we refit only the output layer on task 1 data and re-evaluate the model on task 1.

The two accuracy scores are then compared:

Accuracy 1: accuracy on the task 1 test set before task 2 training
Accuracy 2: accuracy on the task 1 test set after task 2 training (with the output layer refit on task 1)
Forgetting = Accuracy 1 - Accuracy 2
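As a minimal sketch of this metric (the function names are hypothetical, not the authors' code), forgetting is a plain difference of two task 1 test accuracies. Figures 4 and 6 report the percentage change in forgetting relative to the baseline model; we assume that means the relative change with respect to the baseline's forgetting.

```python
def forgetting(acc1: float, acc2: float) -> float:
    """Forgetting = Accuracy 1 - Accuracy 2, both measured on the task 1 test set.

    acc1: task 1 test accuracy right after task 1 training.
    acc2: task 1 test accuracy after task 2 training (output layer refit on task 1).
    """
    return acc1 - acc2


def pct_change(value: float, baseline: float) -> float:
    """Relative change of `value` with respect to `baseline`, in percent (assumed definition)."""
    if baseline == 0:
        raise ValueError("baseline must be non-zero")
    return 100.0 * (value - baseline) / baseline
```

For instance, with hypothetical accuracies of 0.90 before and 0.75 after task 2 training, forgetting is about 0.15; if a baseline forgot 0.10 and a variant forgot 0.12, the variant would show a 20% increase.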

Task 1 and Task 2 Splits

We used 2 sets of splits throughout the experiments.

The first set, used during the initial modeling and exploration stage, consists of 5 random splits (split 0 to split 4).

We formalized the second set after some experiments. It consists of one semantic split, based on image similarity from a human perspective, and two random splits. This set was used for the extended random shift and constant learning rate experiments. The task 1 classes for each split are:

  • Semantic split: airplane, automobile, bird, ship, truck
  • Random split 1: airplane, deer, frog, ship, truck
  • Random split 2: automobile, cat, deer, frog, ship
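The splits can be expressed with a small helper. This is a sketch under two assumptions: integer labels follow the standard CIFAR-10 ordering (the order of the class list in the Data section), and each task's labels are remapped to 0..4 so that each task gets its own 5-way output layer. The names `random_task1_classes` and `split_indices` are hypothetical.

```python
import random

# Standard CIFAR-10 label order: label k corresponds to CIFAR10_CLASSES[k].
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

# Task 1 classes of the semantic split; task 2 gets the remaining five.
SEMANTIC_TASK1 = ["airplane", "automobile", "bird", "ship", "truck"]


def random_task1_classes(seed: int) -> list[str]:
    """Pick 5 of the 10 classes at random for task 1 (reproducible via `seed`)."""
    rng = random.Random(seed)
    return sorted(rng.sample(CIFAR10_CLASSES, 5), key=CIFAR10_CLASSES.index)


def split_indices(labels: list[int], task1_classes: list[str]):
    """Return (task1, task2) lists of (example index, remapped label 0..4) pairs."""
    task1_ids = [CIFAR10_CLASSES.index(c) for c in task1_classes]
    task2_ids = [k for k in range(10) if k not in task1_ids]
    remap1 = {cid: new for new, cid in enumerate(task1_ids)}
    remap2 = {cid: new for new, cid in enumerate(task2_ids)}
    task1 = [(i, remap1[y]) for i, y in enumerate(labels) if y in remap1]
    task2 = [(i, remap2[y]) for i, y in enumerate(labels) if y in remap2]
    return task1, task2
```

For example, labels `[0, 3, 8, 5, 9]` under the semantic split put examples 0, 2 and 4 (airplane, ship, truck) in task 1 with remapped labels 0, 3 and 4.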

Baseline Model Structure: Customized CNN

👁 Figure 3: Baseline model structure. Image by Author
Figure 3: Baseline model structure. Image by Author

Figure 3 shows the structure of our model. It is a customized CNN that performs well across task splits and trials, consistently achieving an Accuracy 1 of around 90%.

Apart from the input and output layers, the model consists of 4 intermediate convolutional blocks, each with 2 convolutional layers, 1 batch normalization layer, 1 max pooling layer, and 1 dropout layer with a dropout rate of 0.2.

Experiment Overview

The main objective of our experiments was to test whether the proposed methods for mitigating forgetting actually work. We pursued two major directions: the first covers experiments with various data augmentation methods, and the second deals with training settings, including learning rate, weight decay, and the number of training epochs.

After initial exploration of these approaches, we specifically extended our experiments on random shifting in data augmentation, and learning rate in training settings.

Data Augmentation Experiments

Basic Data Augmentation

We experimented with some basic data augmentation methods:

  1. Horizontal/Vertical Flip
  2. Random Shift up to 10% and up to 20%
  3. Random Rotation up to 10°, 15°, 20°

These augmentations were applied during both task 1 and task 2 training; no augmentation was applied to the test sets. Figure 4 shows their effect on forgetting compared with the baseline.

Figure 4: Percentage change in forgetting after each data augmentation method was applied, compared to the baseline model. Different colors represent results in different data splits. Image by Author

The results vary between data splits. For most of the data augmentation methods, we cannot conclude whether they affect forgetting: they increased forgetting in some splits but not in others, and the changes were generally small.

However, among all the methods, random shift clearly increased forgetting, and the effect appears to grow with the size of the shift. We confirmed that the Accuracy 1 of models trained with random shift was comparable to that of the baseline without data augmentation, and since the results were consistent across all 5 task splits, we believe that random shift does increase forgetting. This result surprised us and is worth investigating further, because random shift is a popular method for improving model performance on CIFAR-10.
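The flips are simple reversals of the pixel grid, and a fractional shift range translates into pixels by multiplying by the image side. The sketch below uses hypothetical helper names and represents an image as a list of rows; rotation is left out because it needs interpolation. On 32×32 CIFAR-10 images, shifts of up to 10% and 20% correspond to at most 3.2 and 6.4 pixels.

```python
def horizontal_flip(img):
    """Mirror an image left to right; `img` is a list of rows of pixel values."""
    return [list(reversed(row)) for row in img]


def vertical_flip(img):
    """Mirror an image top to bottom (rows are copied, not shared)."""
    return [list(row) for row in reversed(img)]


def max_shift_pixels(fraction: float, size: int = 32) -> float:
    """Largest shift in pixels allowed by a fractional shift range on a `size`-pixel side."""
    return fraction * size
```

Since augmentation is applied only during training, these functions would be called on training batches only, never on the test sets.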

Advanced Data Augmentation: Mixup

Apart from basic data augmentation methods, we also experimented with mixup, a popular method that can improve classification accuracy (see Figure 5). Instead of feeding the network raw images, mixup takes 2 images and forms a linear combination of them, and of their labels, using a weight λ:

$$x_{\text{new}} = \lambda x_1 + (1 - \lambda) x_2$$
$$y_{\text{new}} = \lambda y_1 + (1 - \lambda) y_2$$
Figure 5: Percentage change in Accuracy 1 with different levels of mixup, compared to the baseline model. Image by Author

Here x1 and x2 are the two input images and y1 and y2 their one-hot labels. λ is a number between 0 and 1, drawn from Beta(𝛼, 𝛼) each time, with 𝛼 a predefined hyperparameter. A smaller 𝛼 concentrates λ near 0 or 1 and therefore produces a weaker mixup effect, while a large 𝛼 can lead to underfitting. We experimented with 𝛼 values of [0.1, 0.2, 0.3, 0.4], all within the range suggested by Zhang et al.⁶
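A minimal per-pair sketch of mixup, assuming flattened inputs and one-hot labels stored as Python lists; `mixup_pair` and `one_hot` are hypothetical helper names. λ is drawn with the standard library's `random.betavariate(alpha, alpha)`, and the same λ mixes both the inputs and the labels.

```python
import random


def one_hot(label: int, num_classes: int) -> list[float]:
    """Encode a class index as a one-hot vector."""
    return [1.0 if k == label else 0.0 for k in range(num_classes)]


def mixup_pair(x1, y1, x2, y2, alpha: float, rng=random):
    """Mix two examples with a weight lam drawn from Beta(alpha, alpha).

    x1, x2: flattened inputs of equal length; y1, y2: one-hot label vectors.
    Returns the mixed input, the mixed (soft) label, and lam.
    """
    lam = rng.betavariate(alpha, alpha)
    x_new = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y_new = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x_new, y_new, lam
```

In practice, mixup is commonly applied to a whole batch at once, pairing each example with one from a shuffled copy of the same batch; the per-pair form keeps the arithmetic visible.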

Figure 6: Percentage change in forgetting with different levels of mixup, compared to the baseline model. Different colors represent results in different data splits. Image by Author

Figure 6 shows the results of the mixup experiments. For most data splits and 𝛼 values, mixup did not mitigate forgetting. Its impact on forgetting also varied across data splits: e.g. when 𝛼 = 0.1, forgetting decreased in two data splits but increased in the other.

More on Random Shift

After the initial experiments, we reached the following conclusions:

  1. Random shift increases forgetting under the current setting.
  2. Mixup improves prediction accuracy but does not mitigate forgetting.
  3. The effect of most data augmentation methods on forgetting depends on the data split, so general conclusions are hard to draw.

Based on these, we performed extended experiments with random shift using the second set of task splits.

In addition to the previous experiments, which applied random shift to both task 1 and task 2, we ran experiments where random shift was applied only to task 1; their results are shown as dashed lines in Figure 7. As the images are 32×32, we controlled the maximum number of pixels by which an image could be shifted in each experiment; a random shift of up to 4 pixels is common practice in the field. All experiments were repeated 9 times, and the averages are presented.

Figure 7: Results of applying random shifts of different pixels. Image by Author

In the third plot, when random shift is applied to both task 1 and task 2, there is a clear trend across all data splits: forgetting decreases when images are randomly shifted by up to 1 pixel, and then increases as the maximum shift grows.
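One way to implement a random shift of up to k pixels is sketched here, assuming zero fill for the pixels that enter the frame (the fill mode used in these experiments is not stated). `shift` and `random_shift` are hypothetical names; an image is a list of rows, and each pixel can be a scalar or an (R, G, B) tuple (pass `fill=(0, 0, 0)` in that case).

```python
import random


def shift(img, dy: int, dx: int, fill=0):
    """Translate `img` down by dy and right by dx pixels (negative values move up/left).

    Pixels that enter the frame are set to `fill` (zero fill is an assumption here).
    """
    h, w = len(img), len(img[0])
    out = [[fill] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            src_r, src_c = r - dy, c - dx
            if 0 <= src_r < h and 0 <= src_c < w:
                out[r][c] = img[src_r][src_c]
    return out


def random_shift(img, max_pixels: int, rng=random, fill=0):
    """Shift by (dy, dx), each drawn uniformly from the integers in [-max_pixels, max_pixels]."""
    dy = rng.randint(-max_pixels, max_pixels)
    dx = rng.randint(-max_pixels, max_pixels)
    return shift(img, dy, dx, fill)
```

Padding each side by 4 pixels and taking a random 32×32 crop, the usual CIFAR-10 recipe, yields the same distribution of zero-filled shifts as `random_shift(img, 4)`.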

Learning Rate Experiments

The learning rate controls how much the network updates its weights at each training step. We ran two sets of exploratory experiments to study the effect of learning rate on forgetting.

Same Learning Rate with Rate Decay for Both Tasks

We experimented with 5 initial learning rates, [0.001, 0.005, 0.01, 0.05, 0.1], chosen based on common practice. We applied piecewise learning rate decay, which reduces the learning rate in steps after fixed numbers of iterations. The same learning rate and decay schedule were applied to both task 1 and task 2 training. This experiment was run on one task split only.
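A piecewise-constant (step) decay schedule fits in a few lines. The decay boundaries and the factor of 0.1 used here are hypothetical, since the exact decay points of these experiments are not given; the name `piecewise_lr` is also illustrative.

```python
def piecewise_lr(step: int, initial_lr: float, boundaries, factor: float = 0.1) -> float:
    """Learning rate at `step` under a piecewise-constant (step) decay schedule.

    The rate starts at `initial_lr` and is multiplied by `factor` once for every
    boundary in `boundaries` that `step` has reached (step >= boundary).
    """
    n_reached = sum(1 for b in boundaries if step >= b)
    return initial_lr * factor ** n_reached
```

With `boundaries=[1000, 2000]`, an initial rate of 0.01 becomes 0.001 from iteration 1000 and 0.0001 from iteration 2000, up to floating-point rounding.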

Figure 8: Accuracy 1 and forgetting with different initial learning rates. Image by Author

Figure 8 presents the results. Learning rates 0.05, 0.01, and 0.005 yielded similar Accuracy 1 but different amounts of forgetting; learning rate 0.005 resulted in the least forgetting.

Different Learning Rate with Rate Decay for Two Tasks

We experimented with 4 initial learning rates, [0.005, 0.01, 0.05, 0.1], for each task, giving a total of 16 (task 1 learning rate, task 2 learning rate) pairs.

Figure 9: Forgetting from learning rates with decay experiments in one task split. Image by Author

Figure 9 shows the results. Fixing the initial learning rate of either task 1 or task 2 did not reveal an apparent trend. One observation is that the diagonal, where task 1 and task 2 share the same initial learning rate, seems to have the least forgetting.
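The 16 pairs, and the diagonal runs where both tasks share a rate, can be enumerated as a Cartesian product; a sketch using the standard library's `itertools.product` (variable names are illustrative):

```python
from itertools import product

# Initial learning rates tried for each task.
INITIAL_LRS = [0.005, 0.01, 0.05, 0.1]

# All (task 1 learning rate, task 2 learning rate) combinations: 4 x 4 = 16 pairs.
lr_pairs = list(product(INITIAL_LRS, repeat=2))

# Diagonal pairs: both tasks use the same initial learning rate.
diagonal = [(lr1, lr2) for lr1, lr2 in lr_pairs if lr1 == lr2]
```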

Constant Learning Rate

From these exploratory experiments it was hard to draw further conclusions, because each setting had too many hyperparameters to choose and the forgetting effect could not be attributed to any single one of them. We therefore decided to focus on constant learning rate experiments without rate decay.

We experimented with learning rates [0.0005, 0.001, 0.005, 0.01, 0.05], using the same set of values for task 1 and task 2. These values were selected because they all achieve comparable Accuracy 1.

Figure 10: Forgetting from constant learning rate experiments in three task splits. Image by Author

Figure 10 shows the forgetting effects in the three splits. Heat maps make trends easier to see when one axis is held fixed. From the plots we draw the following conclusions:

  1. Holding the task 1 learning rate fixed, decreasing the task 2 learning rate mitigates forgetting.
  2. Holding the task 2 learning rate fixed, increasing the task 1 learning rate mitigates forgetting.
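To make the two conclusions precise, this sketch checks them on a forgetting grid. It assumes a hypothetical layout in which `grid[i][j]` holds forgetting for the i-th task 1 learning rate and the j-th task 2 learning rate, with both rate lists sorted in increasing order; `trends` is an illustrative name. The check is strict cell by cell, so with noisy measurements comparing averages over repeats would be more robust.

```python
def trends(grid):
    """Check both conclusions on a forgetting grid (rows: task 1 rate, columns: task 2 rate).

    Returns (t2_trend, t1_trend):
      t2_trend: in every row, forgetting never drops as the task 2 rate increases,
                i.e. lowering the task 2 rate never increases forgetting.
      t1_trend: in every column, forgetting never rises as the task 1 rate increases.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    t2_trend = all(grid[i][j] <= grid[i][j + 1]
                   for i in range(n_rows) for j in range(n_cols - 1))
    t1_trend = all(grid[i + 1][j] <= grid[i][j]
                   for j in range(n_cols) for i in range(n_rows - 1))
    return t2_trend, t1_trend
```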

Task 2 Frozen Accuracy Tests

To investigate why increasing the task 1 learning rate could reduce forgetting, we also measured task 2 frozen accuracy: the test accuracy on task 2 of a network trained only on task 1, without any training on task 2. The hypothesis we wanted to test is that a larger task 1 learning rate produces better features, which in turn mitigate forgetting. However, our experiments showed no trend or relationship between the task 1 learning rate and task 2 frozen accuracy.

Train Epoch Experiments

To study the effect of the number of task 1 training epochs on forgetting, we fixed the number of task 2 training epochs. Since training on both tasks converges at around 40 epochs, we set task 2 training epochs = 40 and experimented with task 1 training epochs = [40, 60, 80, 100].

Figure 11: Results of training task 1 with different number of epochs, fixing task 2 training epoch = 40. Image by Author

Figure 11 shows that Accuracy 1 generally increases with the number of epochs, which is consistent with the common observation that longer training leads to better performance. At the same time, however, forgetting also generally increases with the number of task 1 training epochs.

Weight Decay Experiments

We also ran experiments to investigate the influence of weight decay on forgetting. We did 2 sets of experiments:

  1. adding weight decay to the Conv2D layers in the first convolutional block
  2. adding weight decay to the Conv2D layers in the last convolutional block

For each set of experiments, we applied an L2 regularizer with the same rate to both the weights and the biases, and tested rates of [0.01, 0.05, 0.1, 0.2].
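For reference, the penalty an L2 regularizer adds to the training loss is the rate times the sum of squared parameters. This sketch (a hypothetical `l2_penalty` helper on flat lists) applies one rate to both kernel weights and biases, as described above. It has the same form, rate × Σw², as Keras's `l2` regularizer, which can be attached to a `Conv2D` layer through its `kernel_regularizer` and `bias_regularizer` arguments.

```python
def l2_penalty(kernel, bias, rate: float) -> float:
    """L2 term added to the loss: rate * (sum of squared weights + sum of squared biases).

    `kernel` and `bias` are flat lists of a layer's parameters; one rate is used for both.
    """
    return rate * (sum(w * w for w in kernel) + sum(b * b for b in bias))
```

The gradient of this term with respect to a parameter w is 2 · rate · w, so under plain SGD it shrinks every weight toward zero at each step, which is why it is often called weight decay.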

In addition, we performed a grid search over learning rates [0.0008, 0.001, 0.002] and decay rates [5e-5, 1e-4, 5e-4, 1e-3], applying weight decay to all Conv2D layers during either task 1 or task 2 training.

All of these settings gave reasonable Accuracy 1. However, the resulting forgetting did not show consistent trends across data splits, so we could not reach a conclusion on the effect of weight decay on forgetting.

Conclusion and Future Direction

From our four sets of experiments, we found some interesting results related to data augmentation and learning rate.

Our next step is to interpret these results. More specifically, for data augmentation, we want to understand why random shift and mixup affect forgetting the way they do. For learning rate, we would like to explore why increasing the task 1 learning rate mitigates forgetting. One hypothesis is that a larger task 1 learning rate helps the network learn better features, but testing it would require additional control experiments.


[1] Kolouri, S., Ketz, N., Zou, X., Krichmar, J., & Pilly, P. (2019). Attention-based selective plasticity.

[2] Ramasesh, V. V., Dyer, E., & Raghu, M. (2020). Anatomy of catastrophic forgetting: Hidden representations and task semantics. arXiv preprint arXiv:2007.07400.

[3] Kemker, R., McClure, M., Abitino, A., Hayes, T., & Kanan, C. (2017). Measuring catastrophic forgetting in neural networks. arXiv preprint arXiv:1708.02072.

[4] Arora, G., Rahimi, A., & Baldwin, T. (2019). Does an LSTM forget more than a CNN? An empirical study of catastrophic forgetting in NLP. In Proceedings of the 17th Annual Workshop of the Australasian Language Technology Association (pp. 77–86).

[5] Goodfellow, I. J., Mirza, M., Xiao, D., Courville, A., & Bengio, Y. (2013). An empirical investigation of catastrophic forgetting in gradient-based neural networks. arXiv preprint arXiv:1312.6211.

[6] Zhang, H., Cisse, M., Dauphin, Y. N., & Lopez-Paz, D. (2017). mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412.


Written By

Mingyue Wei
