On the Implicit Bias of Adam
Abstract
In previous literature, backward error analysis was used to find ordinary differential equations (ODEs) approximating the gradient descent trajectory. It was found that finite step sizes implicitly regularize solutions because terms appearing in the ODEs penalize the two-norm of the loss gradients. We prove that the existence of similar implicit regularization in RMSProp and Adam depends on their hyperparameters and the training stage, but with a different “norm” involved: the corresponding ODE terms either penalize the (perturbed) one-norm of the loss gradients or, conversely, impede its reduction (the latter case being typical). We also conduct numerical experiments and discuss how the proven facts can influence generalization.
1 Introduction
Gradient descent (GD) can be seen as a numerical method solving the ordinary differential equation (ODE) $\dot{\theta} = -\nabla E(\theta)$, where $E(\cdot)$ is the loss function and $\nabla E(\theta)$ is its gradient. Starting at $\theta^{(0)}$, it creates a sequence of guesses $\theta^{(1)}, \theta^{(2)}, \ldots$, which lie close to the solution trajectory $\theta(t)$ governed by the aforementioned ODE. Since the step size $h$ is finite, one could search for a modified differential equation $\dot{\tilde{\theta}} = -\nabla \tilde{E}(\tilde{\theta})$ such that $\theta^{(n)} - \tilde{\theta}(nh)$ is exactly zero, or at least closer to zero than $\theta^{(n)} - \theta(nh)$, that is, all the guesses of the descent lie exactly on the new solution curve, or closer to it than to the original curve. This approach to analysing properties of a numerical method is sometimes called backward error analysis in the numerical integration literature (see Chapter IX in Ernst Hairer & Wanner (2006) and references therein).
Barrett & Dherin (2021) used this idea for full-batch GD and found that the modified loss function $\tilde{E}(\theta) = E(\theta) + (h/4)\|\nabla E(\theta)\|^2$ makes the trajectory of the solution to $\dot{\tilde{\theta}} = -\nabla \tilde{E}(\tilde{\theta})$ approximate the sequence $\{\theta^{(n)}\}$ one order of $h$ better than the original ODE, where $\|\cdot\|$ is the Euclidean norm. In related work, Miyagawa (2022) obtained the correction term for full-batch GD up to any chosen order, also studying the global error (uniform in the iteration number) as opposed to the local (one-step) error.
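The gain in accuracy can be checked numerically. The following is a minimal sketch, assuming a one-dimensional quadratic loss $E(\theta) = a\theta^2/2$ (an illustrative choice made here, not one used in the cited works), for which GD and both gradient flows have closed-form solutions. The function names and the values of `a`, `h` and `n` are hypothetical.

```python
import math

def gd_iterate(theta0, a, h, n):
    """Run n steps of gradient descent on E(theta) = a * theta**2 / 2."""
    theta = theta0
    for _ in range(n):
        theta -= h * a * theta  # the gradient of E is a * theta
    return theta

def original_flow(theta0, a, t):
    """Exact solution at time t of theta' = -E'(theta) = -a * theta."""
    return theta0 * math.exp(-a * t)

def modified_flow(theta0, a, h, t):
    """Exact solution at time t of theta' = -d/dtheta [E + (h/4) * (E')**2].

    Here (E')**2 = a**2 * theta**2, so the modified gradient is
    (a + h * a**2 / 2) * theta.
    """
    return theta0 * math.exp(-(a + h * a**2 / 2) * t)

theta0, a, h, n = 1.0, 1.0, 0.1, 10
t = n * h
theta_gd = gd_iterate(theta0, a, h, n)
err_original = abs(theta_gd - original_flow(theta0, a, t))
err_modified = abs(theta_gd - modified_flow(theta0, a, h, t))
print(f"error vs original flow: {err_original:.2e}")
print(f"error vs modified flow: {err_modified:.2e}")
```

With these values the modified flow tracks the GD iterate more than ten times more closely than the original flow; shrinking `h` should widen the gap further.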
The analysis was later extended to mini-batch GD in Smith et al. (2021). Assume that the training set is split into batches of size $B$ and there are $m$ batches per epoch (so the training set size is $mB$). The cost function is rewritten as $E(\theta) = (1/m)\sum_{k=0}^{m-1} \hat{E}_k(\theta)$ with mini-batch costs denoted $\hat{E}_k(\theta)$. It was obtained in that work that after one epoch, the mean iterate of the algorithm, averaged over all possible shuffles of the batch indices, is close to the solution to $\dot{\theta} = -\nabla \tilde{E}_{SGD}(\theta)$, where the modified loss is given by $\tilde{E}_{SGD}(\theta) = E(\theta) + \frac{h}{4m}\sum_{k=0}^{m-1}\|\nabla \hat{E}_k(\theta)\|^2$.
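Modified equations have also been derived for GD with heavy-ball momentum $\theta^{(n+1)} = \theta^{(n)} - h\nabla E(\theta^{(n)}) + \beta(\theta^{(n)} - \theta^{(n-1)})$, where $\beta$ is the momentum parameter. In the full-batch setting, it turns out that for large enough $n$ it is close to the continuous trajectory solving
$$\dot{\theta} = -(1-\beta)^{-1}\nabla E(\theta) - h\,\frac{1+\beta}{4(1-\beta)^3}\,\nabla\|\nabla E(\theta)\|^2. \tag{1}$$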
Versions of this general result were proven in Farazmand (2020), Kovachki & Stuart (2021), Ghosh et al. (2023) under different assumptions. The focus of the latter work is the closest to ours since they interpret the correction term as implicit regularization. Their main theorem also provides the analysis for the general mini-batch case.
In another recent work, Zhao et al. (2022) introduce a regularization term to the loss function as a way to ensure finding flatter minima, improving generalization. The only difference between their term and the first-order correction coming from backward error analysis (up to a coefficient) is that the norm is not squared and regularization is applied on a per-batch basis.
The application of backward error analysis for approximating the discrete dynamics of adaptive algorithms such as RMSProp (Tieleman et al., 2012) and Adam (Kingma & Ba, 2015) is currently missing in the literature. Barrett & Dherin (2021) note that “it would be interesting to use backward error analysis to calculate the modified loss and implicit regularization for other widely used optimizers such as momentum, Adam and RMSprop”. Smith et al. (2021) reiterate that they “anticipate that backward error analysis could also be used to clarify the role of finite learning rates in adaptive optimizers like Adam”. Ghosh et al. (2023) agree that “RMSProp … and Adam …, albeit being powerful alternatives to SGD with faster convergence rates, are far from well-understood in the aspect of implicit regularization”. In a similar context, in Appendix G to Miyagawa (2022) it is mentioned that “its [Adam’s] counter term and discretization error are open questions”.
This work fills the gap by conducting backward error analysis for (mini-batch, and full-batch as a special case) Adam and RMSProp. Our main contributions are listed below.
- In Theorem 3.1, we provide a global second-order in $h$ continuous piecewise ODE approximation to Adam in the general mini-batch setting. (A similar result for RMSProp is moved to Appendix C.) For the full-batch special case, it was shown in prior work (Ma et al., 2022) that the continuous-time limit of both these algorithms is a signGD flow component-wise (perturbed by the numerical stability parameter $\varepsilon$); we make this more precise by finding a linear in $h$ correction term on the right.
- We analyze the full-batch case in the context of regularization (see the summary in Section 2). In contrast to the case of GD, where the two-norm of the loss gradient is implicitly penalized, Adam typically anti-penalizes the perturbed one-norm of the loss gradient (i. e., penalizes the negative norm), as specified in (5). Thus, the implicit bias of Adam that we identify serves as anti-regularization (except for the unusual case $\beta \geq \rho$, large $\varepsilon$ or very late in training).
- We provide numerical evidence consistent with our theoretical results by training various vision models on CIFAR-10 using full-batch Adam. In particular, we observe that the stronger the implicit anti-regularization effect predicted by our theory, the worse the generalization. This pattern holds across different architectures: ResNets, simple convolutional neural networks (CNNs) and Vision Transformers. Thus, we propose a novel possible explanation for the often-reported poor generalization of adaptive gradient algorithms. The code used for training the models is available at https://github.com/borshigida/implicit-bias-of-adam.
1.1 Related Work
Backward error analysis of first-order methods. We outlined the history of finding ODEs approximating different algorithms above in the introduction. Recently, there have been other applications of backward error analysis related to machine learning. Kunin et al. (2020) show that the approximating continuous-time trajectories satisfy conservation laws that are broken in discrete time. França et al. (2021) use backward error analysis while studying how to discretize continuous-time dynamical systems preserving stability and convergence rates. Rosca et al. (2021) find continuous-time approximations of discrete two-player differential games.
Approximating gradient methods by differential equation trajectories. Under the assumption that the hyperparameters $\beta, \rho$ of the Adam algorithm (see Definition 1.1) tend to 1 at a certain rate as $h \to 0$, a first-order continuous ODE approximation to this algorithm was derived in Barakat & Bianchi (2021). On the other hand, if $\beta, \rho$ are kept fixed, Ma et al. (2022) prove that the trajectories of Adam and RMSProp are close to signGD dynamics, and investigate different training regimes of these algorithms empirically. SGD is approximated by stochastic differential equations and novel adaptive parameter adjustment policies are devised in Li et al. (2017). Malladi et al. (2022) derive stochastic differential equations that are order-1 weak approximations of RMSProp and Adam. We go in a different direction: instead of refining the previously obtained continuous ODE approximations by taking gradient noise into account, we take a deterministic approach but go one order of $h$ further. In particular, we keep $\beta$, $\rho$ fixed (thus generalizing the analysis for SGD with momentum), whereas Malladi et al. (2022) take $\beta, \rho \to 1$.
Connection with signGD. The connection of adaptive gradient methods with sign(S)GD is extensively discussed in Bernstein et al. (2018). Balles et al. (2020) study a version of signGD with an update proportional to $\|\nabla E(\theta)\|_1 \operatorname{sign}(\nabla E(\theta))$ as a special case of steepest descent, and discuss when sign-based methods are preferable to GD.
Implicit bias of first-order methods. Soudry et al. (2018) prove that GD trained to classify linearly separable data with logistic loss converges to the direction of the max-margin vector (the solution to the hard margin SVM). This result has been extended to different loss functions in Nacson et al. (2019b), to SGD in Nacson et al. (2019c), AdaGrad in Qian & Qian (2019), (S)GD with momentum, deterministic Adam and stochastic RMSProp in Wang et al. (2022), more generic optimization methods in Gunasekar et al. (2018a), to the nonseparable case in Ji & Telgarsky (2018b), Ji & Telgarsky (2019). This line of research has been generalized to studying implicit biases of linear networks (Ji & Telgarsky, 2018a; Gunasekar et al., 2018b), homogeneous neural networks (Ji & Telgarsky, 2020; Nacson et al., 2019a; Lyu & Li, 2019). Woodworth et al. (2020) study the gradient flow of a diagonal linear network with squared loss and show that large initializations lead to minimum two-norm solutions while small initializations lead to minimum one-norm solutions. Even et al. (2023) extend this work to the case of non-zero step sizes and mini-batch training. Wang et al. (2021) prove that Adam and RMSProp maximize the margin of homogeneous neural networks. Our perspective on the implicit bias is different since we are considering a generic loss function without any assumptions on the network architecture. Beneventano (2023) proves that in expectation over batch sampling the trajectory of SGD without replacement differs from that of SGD with replacement by an additional step on a regularizer. As opposed to the work on backward error analysis for SGD discussed above, they do not assume the largest eigenvalue of the hessian to be bounded.
Generalization of adaptive methods. Cohen et al. (2022) investigate the edge-of-stability regime of adaptive gradient algorithms and the effect of sharpness (the largest eigenvalue of the hessian) on generalization. Granziol (2020); Chen et al. (2021) observe that adaptive methods find sharper minima than SGD and Zhou et al. (2020); Xie et al. (2022) argue theoretically that it is the case. Jiang et al. (2022) introduce a statistic that measures the uniformity of the hessian diagonal and argue that adaptive gradient algorithms are biased towards making this statistic smaller. Keskar & Socher (2017) propose to improve generalization of adaptive methods by switching to SGD in the middle of training.
1.2 Notation
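We denote the loss of the $k$th minibatch as a function of the network parameters $\theta \in \mathbb{R}^p$ by $E_k(\theta)$, and in the full-batch setting we omit the index and write $E(\theta)$. $\nabla E$ means the gradient of $E$, and $\nabla$ with indices denotes partial derivatives, e. g. $\nabla_{ijs} E$ is a shortcut for $\frac{\partial^3 E}{\partial\theta_i \partial\theta_j \partial\theta_s}$. The norm notation without indices $\|\cdot\|$ is the two-norm of a vector, $\|\cdot\|_1$ is the one-norm and $\|\cdot\|_{1,\varepsilon}$ is the perturbed one-norm defined as $\|v\|_{1,\varepsilon} := \sum_{i=1}^{p} \sqrt{v_i^2 + \varepsilon}$. (Of course, if $\varepsilon > 0$ the perturbed one-norm is not really a norm, but taking $\varepsilon = 0$ makes it the one-norm.) For a real number $a$ the floor $\lfloor a \rfloor$ is the largest integer not exceeding $a$.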
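As a quick illustration of the perturbed one-norm, the sketch below assumes the definition $\|v\|_{1,\varepsilon} = \sum_i \sqrt{v_i^2 + \varepsilon}$ and checks its two limiting behaviours numerically: for tiny $\varepsilon$ it is close to the usual one-norm, and for large $\varepsilon$ it is close to $p\sqrt{\varepsilon} + \|v\|^2 / (2\sqrt{\varepsilon})$ (a first-order Taylor expansion of the square root). The function name and the test vector are hypothetical.

```python
import numpy as np

def perturbed_one_norm(v, eps):
    """Sum over components of sqrt(v_i**2 + eps)."""
    v = np.asarray(v, dtype=float)
    return float(np.sum(np.sqrt(v**2 + eps)))

v = np.array([0.5, -2.0, 1.5])

# Tiny eps: essentially the one-norm of v.
small_eps_value = perturbed_one_norm(v, 1e-12)
one_norm = float(np.sum(np.abs(v)))

# Large eps: a constant plus a rescaled squared two-norm.
eps_large = 1e6
large_eps_value = perturbed_one_norm(v, eps_large)
taylor_approx = len(v) * np.sqrt(eps_large) + float(v @ v) / (2 * np.sqrt(eps_large))

print(small_eps_value, one_norm)
print(large_eps_value, taylor_approx)
```

The second comparison is the reason the perturbed one-norm behaves like a squared two-norm penalty when $\varepsilon$ dominates the gradient components: the additive constant $p\sqrt{\varepsilon}$ does not affect gradients.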
To provide the names and notations for hyperparameters, we define the algorithm below.
Definition 1.1.
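The Adam algorithm (Kingma & Ba, 2015) is an optimization algorithm with numerical stability hyperparameter $\varepsilon > 0$, squared gradient momentum hyperparameter $\rho \in (0, 1)$, gradient momentum hyperparameter $\beta \in (0, 1)$, initialization $\theta^{(0)} \in \mathbb{R}^p$, $\nu^{(0)} = 0 \in \mathbb{R}^p$, $m^{(0)} = 0 \in \mathbb{R}^p$ and the following update rule: for each $n \geq 0$, $j \in \{1, \ldots, p\}$,
$$\nu_j^{(n+1)} = \rho\,\nu_j^{(n)} + (1-\rho)\bigl(\nabla_j E_n(\theta^{(n)})\bigr)^2, \qquad m_j^{(n+1)} = \beta\, m_j^{(n)} + (1-\beta)\,\nabla_j E_n(\theta^{(n)}),$$
$$\theta_j^{(n+1)} = \theta_j^{(n)} - h\,\frac{m_j^{(n+1)} / (1-\beta^{n+1})}{\sqrt{\nu_j^{(n+1)} / (1-\rho^{n+1}) + \varepsilon}}.$$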
Remark 1.2.
Note that the numerical stability hyperparameter $\varepsilon > 0$, which is introduced in these algorithms to avoid division by zero, is inside the square root in our definition. This way we avoid division by zero in the derivative too: the first derivative of $x \mapsto (\sqrt{x + \varepsilon})^{-1}$ is bounded for $x \geq 0$. This is useful for our analysis. In Theorems B.4 and D.4, the original versions of RMSProp and Adam are also tackled, though with an additional assumption which requires that no component of the gradient can come very close to zero in the region of interest. This is true only for the initial period of learning (whereas Theorem 3.1 tackles the whole period). Practitioners do not seem to make a distinction between the version with $\varepsilon$ inside vs. outside the square root: tutorials with both versions abound on machine learning related websites. Moreover, the popular TensorFlow and Optax variants of RMSProp have $\varepsilon$ inside the square root. Empirically we also observed that moving $\varepsilon$ inside or outside the square root does not change the behavior of Adam or RMSProp qualitatively.
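The update rule can be sketched in a few lines of NumPy. This is a minimal illustration, assuming the standard bias-corrected Adam update with $\varepsilon$ placed inside the square root as described in the remark above; the function names `adam_step` and `run_adam`, the default hyperparameter values and the toy quadratic loss are assumptions made for this example and are not taken from the paper's code.

```python
import numpy as np

def adam_step(theta, m, nu, grad, n, h=1e-3, beta=0.9, rho=0.999, eps=1e-8):
    """One Adam update with eps inside the square root.

    n is the zero-based iteration index, so bias corrections use n + 1.
    """
    nu = rho * nu + (1.0 - rho) * grad**2      # squared-gradient average
    m = beta * m + (1.0 - beta) * grad         # gradient average
    m_hat = m / (1.0 - beta ** (n + 1))        # bias-corrected first moment
    nu_hat = nu / (1.0 - rho ** (n + 1))       # bias-corrected second moment
    theta = theta - h * m_hat / np.sqrt(nu_hat + eps)
    return theta, m, nu

def run_adam(grad_fn, theta0, num_steps, **hyper):
    """Iterate adam_step from theta0 with zero-initialized moment estimates."""
    theta = np.asarray(theta0, dtype=float)
    m = np.zeros_like(theta)
    nu = np.zeros_like(theta)
    for n in range(num_steps):
        theta, m, nu = adam_step(theta, m, nu, grad_fn(theta), n, **hyper)
    return theta

# Toy full-batch example: E(theta) = ||theta||^2 / 2, so the gradient is theta.
grad_fn = lambda theta: theta
theta_after_one_step = run_adam(grad_fn, [1.0, -2.0], 1, h=0.1, eps=1e-8)
print(theta_after_one_step)
```

On the first step the bias corrections cancel the factors $(1-\beta)$ and $(1-\rho)$, so each component moves by $h\, g_j / \sqrt{g_j^2 + \varepsilon}$, which is close to $h \operatorname{sign}(g_j)$ for small $\varepsilon$. This is the sign-descent behaviour referred to throughout the text.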
2 Implicit Bias of Full-Batch Adam: an Informal Summary
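We are ready to describe our theoretical result (Theorem 3.1 below) in the full-batch special case. Assume $E(\theta)$ is the loss, whose partial derivatives up to the fourth order are bounded. Let $\{\theta^{(n)}\}$ be iterations of Adam as defined in Definition 1.1. We find an ODE whose solution trajectory $\tilde{\theta}(t)$ is $h^2$-close to $\{\theta^{(n)}\}$, meaning that for any time horizon $T > 0$ there is a constant $C$ such that for any step size $h \in (0, T)$ we have $\|\tilde{\theta}(nh) - \theta^{(n)}\| \leq C h^2$ (for $n$ between $0$ and $\lfloor T/h \rfloor$). The ODE is written the following way (up to terms that rapidly go to zero as $n$ grows): for the component number $j \in \{1, \ldots, p\}$
$$\dot{\tilde{\theta}}_j(t) = -\frac{1}{\sqrt{|\nabla_j E(\tilde{\theta}(t))|^2 + \varepsilon}}\Bigl(\nabla_j E(\tilde{\theta}(t)) + \mathrm{correction}_j(\tilde{\theta}(t))\Bigr) \tag{2}$$
with initial conditions $\tilde{\theta}_j(0) = \theta_j^{(0)}$ for all $j$, where the correction term is
$$\mathrm{correction}_j(\theta) := \frac{h}{2}\left\{\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho} + \frac{1+\rho}{1-\rho}\cdot\frac{\varepsilon}{|\nabla_j E(\theta)|^2 + \varepsilon}\right\}\nabla_j\|\nabla E(\theta)\|_{1,\varepsilon}. \tag{3}$$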
Depending on hyperparameters and the training stage, the correction term can take two extreme forms listed below. The reality is in between, but typically much closer to the first case.
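- If $\sqrt{\varepsilon}$ is small compared to all components of $\nabla E(\tilde{\theta}(t))$, i. e. $\min_j |\nabla_j E(\tilde{\theta}(t))| \gg \sqrt{\varepsilon}$, which is usually the case during most of the training, then we can write
$$\mathrm{correction}_j(\theta) \approx \frac{h}{2}\left\{\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}\right\}\nabla_j\|\nabla E(\theta)\|_{1,\varepsilon}. \tag{4}$$
For small $\varepsilon$, the perturbed one-norm is indistinguishable from the usual one-norm, and for $\beta > \rho$ it is penalized (in much the same way as the squared two-norm is implicitly penalized in the case of GD), but for the typical case $\rho > \beta$ its decrease is actually hindered by this term (so the bias is anti-regularization). The ODE in (2) approximately becomes
$$\dot{\tilde{\theta}}_j(t) = -\frac{\nabla_j \tilde{E}(\tilde{\theta}(t))}{|\nabla_j E(\tilde{\theta}(t))|}, \qquad \tilde{E}(\theta) := E(\theta) + \frac{h}{2}\left\{\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}\right\}\|\nabla E(\theta)\|_1. \tag{5}$$
- If $\sqrt{\varepsilon}$ is large compared to all gradient components, i. e. $\max_j |\nabla_j E(\tilde{\theta}(t))| \ll \sqrt{\varepsilon}$ (which may happen during the late learning stage, or if non-standard hyperparameter values are chosen), the fraction in (3) with $\varepsilon$ in the numerator approaches one, the dependence on $\rho$ cancels out, and
$$\mathrm{correction}_j(\theta) \approx \frac{h}{2}\cdot\frac{1+\beta}{1-\beta}\,\nabla_j\|\nabla E(\theta)\|_{1,\varepsilon}. \tag{6}$$
In other words, $\|\cdot\|_{1,\varepsilon}$ becomes $\|\cdot\|^2 / (2\sqrt{\varepsilon})$ up to an additive constant, giving
$$\mathrm{correction}_j(\theta) \approx \frac{h}{4\sqrt{\varepsilon}}\cdot\frac{1+\beta}{1-\beta}\,\nabla_j\|\nabla E(\theta)\|^2.$$
The form of the ODE in this case is
$$\dot{\tilde{\theta}}_j(t) = -\frac{1}{\sqrt{\varepsilon}}\,\nabla_j\left(E(\tilde{\theta}(t)) + \frac{h}{4\sqrt{\varepsilon}}\cdot\frac{1+\beta}{1-\beta}\,\|\nabla E(\tilde{\theta}(t))\|^2\right). \tag{7}$$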
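The two extreme cases can be checked numerically. The sketch below assumes the correction term has the form $\frac{h}{2}\bigl\{\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho} + \frac{1+\rho}{1-\rho}\cdot\frac{\varepsilon}{|\nabla_j E|^2+\varepsilon}\bigr\}\nabla_j\|\nabla E\|_{1,\varepsilon}$ and evaluates it at a single point given a gradient vector and a Hessian matrix. The function name `correction`, the gradient, the Hessian and the hyperparameter values are all hypothetical and chosen only for illustration.

```python
import numpy as np

def correction(grad, hess, h, beta, rho, eps):
    """Assumed first-order correction term of full-batch Adam, per component.

    grad: gradient of E at theta, shape (p,)
    hess: (symmetric) Hessian of E at theta, shape (p, p)
    """
    # Gradient of the perturbed one-norm sum_i sqrt(g_i^2 + eps) w.r.t. theta:
    # component j is sum_i H_ji * g_i / sqrt(g_i^2 + eps).
    grad_of_norm = hess @ (grad / np.sqrt(grad**2 + eps))
    a_beta = (1 + beta) / (1 - beta)
    a_rho = (1 + rho) / (1 - rho)
    coef = a_beta - a_rho + a_rho * eps / (grad**2 + eps)
    return 0.5 * h * coef * grad_of_norm

grad = np.array([1.0, -2.0])
hess = np.array([[2.0, 0.5], [0.5, 1.0]])
h, beta, rho = 0.01, 0.9, 0.999

# Small eps: coefficient (1+beta)/(1-beta) - (1+rho)/(1-rho) times grad of one-norm.
small = correction(grad, hess, h, beta, rho, eps=1e-12)
small_coef = (1 + beta) / (1 - beta) - (1 + rho) / (1 - rho)
small_limit = 0.5 * h * small_coef * (hess @ np.sign(grad))

# Large eps: rho drops out, squared two-norm appears (its gradient is 2 * H @ g).
eps_large = 1e12
large = correction(grad, hess, h, beta, rho, eps=eps_large)
large_limit = h / (4 * np.sqrt(eps_large)) * (1 + beta) / (1 - beta) * (2 * hess @ grad)

print(small, small_limit)
print(large, large_limit)
print("small-eps coefficient:", small_coef)
```

For the commonly used values $\beta = 0.9$ and $\rho = 0.999$ (an assumption about typical practice, not a claim from the text), the small-$\varepsilon$ coefficient is roughly $19 - 1999 = -1980$. Its negative sign is what makes the bias act as anti-regularization in the typical regime.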
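These two extreme cases are summarized in Table 1. In Figure 1, we use the one-dimensional ($p = 1$) case to illustrate what kind of term is being implicitly penalized.
| | $\varepsilon$ “small” | $\varepsilon$ “large” |
|---|---|---|
| $\beta \geq \rho$ | $\|\nabla E(\theta)\|_1$-penalized | $\|\nabla E(\theta)\|^2$-penalized |
| $\rho > \beta$ | $-\|\nabla E(\theta)\|_1$-penalized | $\|\nabla E(\theta)\|^2$-penalized |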
Usually, $\varepsilon$ is chosen to be small, and during most of the training Adam is much better described by the first extreme case. It is clear from (5) that, if $\rho > \beta$, the correction term provides the opposite of regularization, in contrast to (1). The larger $\rho$ is compared to $\beta$, the stronger the anti-regularization effect.
This finding may partially explain why adaptive gradient methods have been reported to generalize worse than non-adaptive ones (Chen et al., 2018; Wilson et al., 2017), as it offers a previously unknown perspective on why they are biased towards “higher-curvature” regions and find “sharper” minima. Indeed, note that standard (non-adaptive) $\ell_\infty$-sharpness at $\theta$ can be defined by $\max_{\|\delta\|_\infty \leq r} E(\theta + \delta) - E(\theta)$ for some radius $r$. This or similar definitions have often been considered in the literature, see, e. g., Andriushchenko et al. (2023), Foret et al. (2021). Replacing the difference of the losses with its first-order approximation under the maximum (Foret et al., 2021; Ghosh et al., 2023)
$$\max_{\|\delta\|_\infty \leq r} \nabla E(\theta)^{\top}\delta = r\,\|\nabla E(\theta)\|_1,$$
we see that Adam typically anti-penalizes the approximation of $\ell_\infty$-sharpness. Although the connection between sharpness and generalization is not clear-cut (Andriushchenko et al., 2023), our empirical results (Section 5) are consistent with this theory.
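The first-order approximation of sharpness used above rests on the fact that a linear function maximized over an $\ell_\infty$-ball of radius $r$ equals $r$ times the one-norm of its coefficient vector (the maximizer is $\delta = r \operatorname{sign}(g)$). The sketch below verifies this by brute force over the corners of the box, where the maximum of a linear function must be attained. The function name and the example gradient are hypothetical.

```python
import itertools
import numpy as np

def linearized_linf_sharpness(grad, r):
    """Maximize grad . delta over the box ||delta||_inf <= r by checking corners."""
    corners = itertools.product([-r, r], repeat=len(grad))
    return max(float(np.dot(grad, corner)) for corner in corners)

grad = np.array([0.3, -1.2, 0.5])
r = 0.1
brute_force = linearized_linf_sharpness(grad, r)
closed_form = r * float(np.sum(np.abs(grad)))
print(brute_force, closed_form)
```

Both numbers agree (up to floating-point rounding), which is why penalizing or anti-penalizing the one-norm of the gradient can be read as acting on a linearized $\ell_\infty$-sharpness.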
This overview also applies to RMSProp by setting $\beta = 0$; see Theorem C.4 for the formal result.
Example 2.1(Backward Error Analysis for GD with Heavy-ball Momentum).
Assume $\varepsilon$ is large compared to all squared gradient components during the whole training process, so that the form of the ODE is approximated by (7). Since Adam with a large $\varepsilon$ and after a certain number of iterations approximates SGD with heavy-ball momentum with step size $h(1-\beta)/\sqrt{\varepsilon}$, a linear step size change (and corresponding time change) gives exactly the equations in Theorem 4.1 of Ghosh et al. (2023). Taking $\beta = 0$ (no momentum), we get the implicit regularization of GD from Barrett & Dherin (2021).
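The reduction to heavy-ball momentum in this example can be seen from a short informal calculation. It is a sketch under two stated assumptions: $n$ is large enough that the bias-correction factors $1-\beta^{n+1}$ and $1-\rho^{n+1}$ are approximately one, and $\varepsilon$ dominates $\nu_j^{(n+1)}$, so the denominator is approximately $\sqrt{\varepsilon}$. The symbols $m^{(n)}$, $\nu^{(n)}$ denote the first and second moment estimates of Adam.

```latex
\begin{align*}
\theta^{(n+1)} - \theta^{(n)}
  &\approx -\frac{h}{\sqrt{\varepsilon}}\, m^{(n+1)}
   = -\frac{h}{\sqrt{\varepsilon}}
     \Bigl(\beta\, m^{(n)} + (1-\beta)\,\nabla E(\theta^{(n)})\Bigr) \\
  &\approx \beta\,\bigl(\theta^{(n)} - \theta^{(n-1)}\bigr)
   - \frac{h(1-\beta)}{\sqrt{\varepsilon}}\,\nabla E(\theta^{(n)}),
\end{align*}
% The last step uses the same approximation one iteration earlier:
% \theta^{(n)} - \theta^{(n-1)} \approx -(h/\sqrt{\varepsilon})\, m^{(n)}.
```

This is the heavy-ball iteration with momentum parameter $\beta$ and step size $h(1-\beta)/\sqrt{\varepsilon}$, which is where the step size rescaling in the example comes from.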
3 Main Result: ODE Approximating Mini-Batch Adam
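We only make one assumption, which is standard in the literature: the loss $E_k$ for each mini-batch is 4 times continuously differentiable, and partial derivatives of $E_k$ up to order 4 are bounded, i. e. there is a positive constant $M$ such that for $\theta$ in the region of interest
$$\sup_{k}\,\sup_{\theta}\Bigl\{\max_{i}|\nabla_i E_k(\theta)| \vee \max_{i,j}|\nabla_{ij} E_k(\theta)| \vee \max_{i,j,s}|\nabla_{ijs} E_k(\theta)| \vee \max_{i,j,s,r}|\nabla_{ijsr} E_k(\theta)|\Bigr\} \leq M. \tag{8}$$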
Theorem 3.1.
Assume (8) holds. Let $\{\theta^{(n)}\}$ be iterations of Adam as defined in Definition 1.1, and let $\tilde{\theta}(t)$ be the continuous solution to the piecewise ODE
| (9) |
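for $t \in [nh, (n+1)h]$ with the initial condition $\tilde{\theta}(0) = \theta^{(0)}$, where
Then, for any fixed positive time horizon $T > 0$ there exists a constant $C$ (depending on $T$, $\rho$, $\beta$, $\varepsilon$) such that for any step size $h \in (0, T)$ we have $\|\tilde{\theta}(nh) - \theta^{(n)}\| \leq C h^2$ for $n \in \{0, \ldots, \lfloor T/h \rfloor\}$.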
Remark 3.2.
Derivation sketch.
The proof is in the appendix (this is Theorem E.4; see Appendix A for the overview of the appendix). To help the reader understand how the ODE (9) is obtained, apart from the full proof, we include an informal derivation in Appendix I, and provide an even briefer sketch of this derivation here.
Our goal is to find such a trajectory that, denoting , we have
| (10) |
Ignoring the terms of order higher than one in , we can take a first-order approximation for granted: with . The challenge is to make this more precise by finding an equality of the form
| (11) |
where is a known function. This is a numerical iteration to which standard backward error analysis (Chapter IX in Ernst Hairer & Wanner (2006)) can be applied.
Using the Taylor series, we can write
where in the last equality we just replaced with in the -term since it only affects higher-order terms. Doing this again for steps , , , and adding the resulting equations will give for
where we could safely ignore that is not bounded because of exponential averaging. Taking the square of this formal power series in , multiplying this square by and summing up over will give
which, using the expression for the inverse square root of a formal power series , gives us an expansion
A similar process provides an expansion for :
Inserting these two expansions into (10) leads to an expression for :
We are now ready to find an ODE for of the form whose discretization is (11). This is a task for standard backward error analysis: expand into . By Taylor expansion, we have
It is left to equate the terms before the corresponding powers of here and in (11), giving and . Omitting some algebra, the piecewise ODE (9) is derived. ∎
4 Illustration: Simple Bilinear Model
We now analyze the effect of the first-order term for Adam in the same model as Barrett & Dherin (2021) and Ghosh et al. (2023) have studied. Namely, assume the parameter $\theta = (\theta_1, \theta_2)$ is 2-dimensional, and the loss is given by $E(\theta) := \frac{1}{2}\bigl(\frac{3}{2} - 2\theta_1\theta_2\bigr)^2$. The loss is minimized on the hyperbola $\theta_1\theta_2 = 3/4$. We graph the trajectories of Adam in this case: the left part of Figure 2 shows that increasing $\beta$ forces the trajectory to the region with smaller $\|\nabla E(\theta)\|_1$, and increasing $\rho$ does the opposite. The right part shows that increasing the learning rate moves Adam towards the region with smaller $\|\nabla E(\theta)\|_1$ if $\beta > \rho$ (just like in the case of GD, except the norm is different if $\varepsilon$ is small compared to gradient components), and does the opposite if $\rho > \beta$. All these observations are exactly what Theorem 3.1 predicts.
5 Numerical Experiments
As a first sanity check, we train a relatively small fully-connected neural network with around parameters on the first 10,000 images of MNIST with full-batch Adam for 100 epochs and plot the value $\|\theta^{(n)} - \tilde{\theta}(nh)\|_\infty$, i. e. the maximal weight difference between the Adam iteration and the piecewise ODE solution. (Footnote 1: Since it makes little sense to numerically solve an ODE by further discretization, $\tilde{\theta}(nh)$ is estimated using the iteration (11) with the higher-order remainder term ignored. Strictly speaking, this is not the trajectory obtained by the final backward error analysis step but rather the step immediately preceding it, after removing long-term memory but before converting the iteration to an ODE.) We see in Figure 3 that even on this very large time horizon the trajectories are close in infinity-norm.
Further, we offer some preliminary empirical evidence that Adam (anti-)penalizes the perturbed one-norm of the gradients, as discussed in Section 2.
Ma et al. (2022) divide training regimes of Adam into three categories: the spike regime when $\rho$ is much larger than $\beta$, in which the training loss curve contains very large spikes and the training is obviously unstable; the (stable) oscillation regime when $\rho$ is sufficiently close to $\beta$, in which the loss curve contains fast and small oscillations; the divergence regime when $\beta$ is much larger than $\rho$, in which Adam diverges. We exclude the last regime. In the spike regime, the loss spikes to large values at irregular intervals. This has also been observed in the context of large transformers, and mitigation strategies have been proposed in Chowdhery et al. (2022) and Molybog et al. (2023). Since it is unlikely that an unstable Adam trajectory can be meaningfully approximated by a smooth ODE solution, we reduce the incidence of large spikes by only considering $\beta$ and $\rho$ that are not too far apart, which is what Ma et al. (2022) recommend in practice.
We train ResNet-50, CNNs and Vision Transformers (Dosovitskiy et al., 2020) on the CIFAR-10 dataset with full-batch Adam. In this section, we provide the results for ResNet-50; the pictures for CNNs and Transformers are similar and are given in Section H.4. Figure 4 shows that in the stable oscillation regime increasing $\rho$ appears to increase the perturbed one-norm (consistent with our analysis: the smaller $\rho$, the more this “norm” is penalized) and decrease the test accuracy. Figure 5 shows that increasing $\beta$ appears to decrease the perturbed one-norm (consistent with our analysis: the larger $\beta$, the more this “norm” is penalized) and increase the test accuracy. The picture confirms the finding in Ghosh et al. (2023) (for the momentum parameter in momentum GD).
Figure 6 shows the graphs of $\|\nabla E\|_{1,\varepsilon}$ as functions of the epoch number. The “norm” decreases, then rises again, and then decreases further until it flatlines. (Footnote 2: Note that the perturbed one-norm cannot be near-zero at the end of training because it is bounded from below by $p\sqrt{\varepsilon}$.) Throughout most of the training, the larger $\beta$ the smaller the “norm”. The “hills” of the “norm” curves are higher with smaller $\beta$ and larger $\rho$. This is consistent with our analysis because the larger $\rho$ is compared to $\beta$, the more $\|\nabla E\|_{1,\varepsilon}$ is prevented from falling by the correction term.
6 Limitations and Future Directions
As far as we know, the assumption similar to (8) is explicitly or implicitly present in all previous work on backward error analysis of gradient-based machine learning algorithms. (Recently, Beneventano (2023) weakened this assumption for SGD without replacement, but their focus is somewhat different.) There is evidence that large-batch algorithms often operate near or at the edge of stability (Cohen et al., 2021, 2022), in which the largest eigenvalue of the hessian can be large, making it unclear whether the higher-order partial derivatives can safely be assumed bounded near optimality. In addition, as Smith et al. (2021) point out, in the mini-batch setting backward error analysis can be more accurate. We leave a qualitative analysis of the behavior of first-order terms in Theorem 3.1 in the mini-batch case as a future direction.
Relatedly, Adam does not always generalize worse than SGD: for transformers, Adam often outperforms SGD (Zhang et al., 2020; Kumar et al., 2022). Moreover, for NLP tasks a long time can be spent training close to an interpolating solution. Our analysis suggests that in the latter regime the anti-regularization effect disappears, which is consistent with the finding that generalization can be better. However, we believe this explanation is not complete, and more work is needed to connect the implicit bias to the training dynamics of transformers.
In addition, the constant $C$ in Theorem 3.1 goes to infinity as $\varepsilon$ goes to zero. Theoretically, our proof does not exclude the case where for very small $\varepsilon$ the trajectory of the piecewise ODE is only close to the Adam trajectory for small, suboptimal learning rates, at least at later stages of learning. (For the initial learning period, this is not a problem.) It appears to also be true of Proposition 1 in Ma et al. (2022) (zeroth-order approximation by sign-GD). This is especially noticeable in the large-spike regime of training (see Section 5) which, despite being obviously unstable, can still lead to acceptable test errors. It would be worthwhile to investigate this regime in detail.
Acknowledgments
We extend our special thanks to Boris Hanin, Sam Smith, and the anonymous reviewers for their insightful comments and suggestions that greatly enhanced this work. We are also grateful to Jianqing Fan, Pier Beneventano, and Rae Yu for engaging and productive discussions. Cattaneo gratefully acknowledges financial support from the National Science Foundation through DMS-2210561 and SES-2241575. Klusowski gratefully acknowledges financial support from the National Science Foundation through CAREER DMS-2239448. Additionally, we acknowledge the Princeton Research Computing resources, coordinated by the Princeton Institute for Computational Science and Engineering (PICSciE) and the Office of Information Technology’s Research Computing.
Impact Statement
This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.
References
- Andriushchenko et al. (2023) Andriushchenko, M., Croce, F., Müller, M., Hein, M., and Flammarion, N. A modern look at the relationship between sharpness and generalization. In Proceedings of the 40th International Conference on Machine Learning, ICML’23. JMLR.org, 2023.
- Balles et al. (2020) Balles, L., Pedregosa, F., and Roux, N. L. The geometry of sign gradient descent. arXiv preprint arXiv:2002.08056, 2020.
- Barakat & Bianchi (2021) Barakat, A. and Bianchi, P. Convergence and dynamical behavior of the adam algorithm for nonconvex stochastic optimization. SIAM Journal on Optimization, 31(1):244–274, 2021.
- Barrett & Dherin (2021) Barrett, D. and Dherin, B. Implicit gradient regularization. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=3q5IqUrkcF.
- Beneventano (2023) Beneventano, P. On the trajectories of sgd without replacement. arXiv preprint arXiv:2312.16143, 2023.
- Bernstein et al. (2018) Bernstein, J., Wang, Y.-X., Azizzadenesheli, K., and Anandkumar, A. signsgd: Compressed optimisation for non-convex problems. In International Conference on Machine Learning, pp. 560–569. PMLR, 2018.
- Beyer et al. (2022) Beyer, L., Zhai, X., and Kolesnikov, A. Better plain vit baselines for imagenet-1k. arXiv preprint arXiv:2205.01580, 2022.
- Chen et al. (2018) Chen, J., Zhou, D., Tang, Y., Yang, Z., Cao, Y., and Gu, Q. Closing the generalization gap of adaptive gradient methods in training deep neural networks. arXiv preprint arXiv:1806.06763, 2018. URL https://arxiv.org/pdf/1806.06763.
- Chen et al. (2021) Chen, X., Hsieh, C.-J., and Gong, B. When vision transformers outperform resnets without pre-training or strong data augmentations. arXiv preprint arXiv:2106.01548, 2021.
- Chowdhery et al. (2022) Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H. W., Sutton, C., Gehrmann, S., et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
- Cohen et al. (2021) Cohen, J., Kaur, S., Li, Y., Kolter, J. Z., and Talwalkar, A. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=jh-rTtvkGeM.
- Cohen et al. (2022) Cohen, J. M., Ghorbani, B., Krishnan, S., Agarwal, N., Medapati, S., Badura, M., Suo, D., Cardoze, D., Nado, Z., Dahl, G. E., et al. Adaptive gradient methods at the edge of stability. arXiv preprint arXiv:2207.14484, 2022.
- Dosovitskiy et al. (2020) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
- Ernst Hairer & Wanner (2006) Hairer, E., Lubich, C., and Wanner, G. Geometric numerical integration. Springer-Verlag, Berlin, 2nd edition, 2006. ISBN 3-540-30663-3.
- Even et al. (2023) Even, M., Pesme, S., Gunasekar, S., and Flammarion, N. (s) gd over diagonal linear networks: Implicit regularisation, large stepsizes and edge of stability. arXiv preprint arXiv:2302.08982, 2023.
- Farazmand (2020) Farazmand, M. Multiscale analysis of accelerated gradient methods. SIAM Journal on Optimization, 30(3):2337–2354, 2020.
- Foret et al. (2021) Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=6Tm1mposlrM.
- França et al. (2021) França, G., Jordan, M. I., and Vidal, R. On dissipative symplectic integration with applications to gradient-based optimization. Journal of Statistical Mechanics: Theory and Experiment, 2021(4):043402, 2021.
- Ghosh et al. (2023) Ghosh, A., Lyu, H., Zhang, X., and Wang, R. Implicit regularization in heavy-ball momentum accelerated stochastic gradient descent. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=ZzdBhtEH9yB.
- Granziol (2020) Granziol, D. Flatness is a false friend. arXiv preprint arXiv:2006.09091, 2020.
- Gunasekar et al. (2018a) Gunasekar, S., Lee, J., Soudry, D., and Srebro, N. Characterizing implicit bias in terms of optimization geometry. In International Conference on Machine Learning, pp. 1832–1841. PMLR, 2018a.
- Gunasekar et al. (2018b) Gunasekar, S., Lee, J. D., Soudry, D., and Srebro, N. Implicit bias of gradient descent on linear convolutional networks. Advances in neural information processing systems, 31, 2018b.
- He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.
- Hoffer et al. (2017) Hoffer, E., Hubara, I., and Soudry, D. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. Advances in neural information processing systems, 30, 2017.
- Ji & Telgarsky (2018a) Ji, Z. and Telgarsky, M. Gradient descent aligns the layers of deep linear networks. arXiv preprint arXiv:1810.02032, 2018a.
- Ji & Telgarsky (2018b) Ji, Z. and Telgarsky, M. Risk and parameter convergence of logistic regression. arXiv preprint arXiv:1803.07300, 2018b.
- Ji & Telgarsky (2019) Ji, Z. and Telgarsky, M. The implicit bias of gradient descent on nonseparable data. In Conference on Learning Theory, pp. 1772–1798. PMLR, 2019.
- Ji & Telgarsky (2020) Ji, Z. and Telgarsky, M. Directional convergence and alignment in deep learning. Advances in Neural Information Processing Systems, 33:17176–17186, 2020.
- Jiang et al. (2022) Jiang, K., Malik, D., and Li, Y. How does adaptive optimization impact local neural network geometry? arXiv preprint arXiv:2211.02254, 2022.
- Keskar & Socher (2017) Keskar, N. S. and Socher, R. Improving generalization performance by switching from adam to sgd. arXiv preprint arXiv:1712.07628, 2017.
- Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
- Kovachki & Stuart (2021) Kovachki, N. B. and Stuart, A. M. Continuous time analysis of momentum methods. Journal of Machine Learning Research, 22(17):1–40, 2021.
- Kumar et al. (2022) Kumar, A., Shen, R., Bubeck, S., and Gunasekar, S. How to fine-tune vision models with SGD. arXiv preprint arXiv:2211.09359, 2022.
- Kunin et al. (2020) Kunin, D., Sagastuy-Brena, J., Ganguli, S., Yamins, D. L., and Tanaka, H. Neural mechanics: Symmetry and broken conservation laws in deep learning dynamics. arXiv preprint arXiv:2012.04728, 2020.
- Lee et al. (2015) Lee, C.-Y., Xie, S., Gallagher, P., Zhang, Z., and Tu, Z. Deeply-supervised nets. In Artificial intelligence and statistics, pp. 562–570. PMLR, 2015.
- Li et al. (2017) Li, Q., Tai, C., and E, W. Stochastic modified equations and adaptive stochastic gradient algorithms. In Precup, D. and Teh, Y. W. (eds.), Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pp. 2101–2110. PMLR, 2017. URL https://proceedings.mlr.press/v70/li17f.html.
- Lyu & Li (2019) Lyu, K. and Li, J. Gradient descent maximizes the margin of homogeneous neural networks. arXiv preprint arXiv:1906.05890, 2019.
- Ma et al. (2022) Ma, C., Wu, L., and E, W. A qualitative study of the dynamic behavior for adaptive gradient algorithms. In Mathematical and Scientific Machine Learning, pp. 671–692. PMLR, 2022.
- Malladi et al. (2022) Malladi, S., Lyu, K., Panigrahi, A., and Arora, S. On the SDEs and scaling rules for adaptive gradient algorithms. Advances in Neural Information Processing Systems, 35:7697–7711, 2022.
- Miyagawa (2022) Miyagawa, T. Toward equation of motion for deep neural networks: Continuous-time gradient descent and discretization error analysis. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=qq84D17BPu.
- Molybog et al. (2023) Molybog, I., Albert, P., Chen, M., DeVito, Z., Esiobu, D., Goyal, N., Koura, P. S., Narang, S., Poulton, A., Silva, R., et al. A theory on Adam instability in large-scale machine learning. arXiv preprint arXiv:2304.09871, 2023.
- Nacson et al. (2019a) Nacson, M. S., Gunasekar, S., Lee, J., Srebro, N., and Soudry, D. Lexicographic and depth-sensitive margins in homogeneous and non-homogeneous deep models. In International Conference on Machine Learning, pp. 4683–4692. PMLR, 2019a.
- Nacson et al. (2019b) Nacson, M. S., Lee, J., Gunasekar, S., Savarese, P. H. P., Srebro, N., and Soudry, D. Convergence of gradient descent on separable data. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 3420–3428. PMLR, 2019b.
- Nacson et al. (2019c) Nacson, M. S., Srebro, N., and Soudry, D. Stochastic gradient descent on separable data: Exact convergence with a fixed learning rate. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 3051–3059. PMLR, 2019c.
- Qian & Qian (2019) Qian, Q. and Qian, X. The implicit bias of AdaGrad on separable data. Advances in Neural Information Processing Systems, 32, 2019.
- Rosca et al. (2021) Rosca, M. C., Wu, Y., Dherin, B., and Barrett, D. Discretization drift in two-player games. In International Conference on Machine Learning, pp. 9064–9074. PMLR, 2021.
- Smith et al. (2021) Smith, S. L., Dherin, B., Barrett, D., and De, S. On the origin of implicit regularization in stochastic gradient descent. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=rq_Qr0c1Hyo.
- Soudry et al. (2018) Soudry, D., Hoffer, E., Nacson, M. S., Gunasekar, S., and Srebro, N. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878, 2018.
- Tieleman et al. (2012) Tieleman, T., Hinton, G., et al. Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2):26–31, 2012.
- Wang et al. (2021) Wang, B., Meng, Q., Chen, W., and Liu, T.-Y. The implicit bias for adaptive optimization algorithms on homogeneous neural networks. In Meila, M. and Zhang, T. (eds.), Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pp. 10849–10858. PMLR, 2021. URL https://proceedings.mlr.press/v139/wang21q.html.
- Wang et al. (2022) Wang, B., Meng, Q., Zhang, H., Sun, R., Chen, W., Ma, Z.-M., and Liu, T.-Y. Does momentum change the implicit regularization on separable data? Advances in Neural Information Processing Systems, 35:26764–26776, 2022.
- Wilson et al. (2017) Wilson, A. C., Roelofs, R., Stern, M., Srebro, N., and Recht, B. The marginal value of adaptive gradient methods in machine learning. Advances in neural information processing systems, 30, 2017.
- Woodworth et al. (2020) Woodworth, B., Gunasekar, S., Lee, J. D., Moroshko, E., Savarese, P., Golan, I., Soudry, D., and Srebro, N. Kernel and rich regimes in overparametrized models. In Conference on Learning Theory, pp. 3635–3673. PMLR, 2020.
- Xie et al. (2022) Xie, Z., Wang, X., Zhang, H., Sato, I., and Sugiyama, M. Adaptive inertia: Disentangling the effects of adaptive learning rate and momentum. In International conference on machine learning, pp. 24430–24459. PMLR, 2022.
- Zhang et al. (2020) Zhang, J., Karimireddy, S. P., Veit, A., Kim, S., Reddi, S., Kumar, S., and Sra, S. Why are adaptive methods good for attention models? Advances in Neural Information Processing Systems, 33:15383–15393, 2020.
- Zhao et al. (2022) Zhao, Y., Zhang, H., and Hu, X. Penalizing gradient norm for efficiently improving generalization in deep learning. In International Conference on Machine Learning, pp. 26982–26992. PMLR, 2022.
- Zhou et al. (2020) Zhou, P., Feng, J., Ma, C., Xiong, C., Hoi, S. C. H., et al. Towards theoretically understanding why SGD generalizes better than Adam in deep learning. Advances in Neural Information Processing Systems, 33:21285–21296, 2020.
Appendix A Overview
The appendix provides some omitted details and proofs.
We consider two algorithms, RMSProp and Adam, and two versions of each algorithm (with the numerical stability parameter inside and outside of the square root in the denominator). This means there are four main theorems: Theorem B.4, Theorem C.4, Theorem D.4 and Theorem E.4, each residing in a section devoted to one algorithm. The simple induction argument taken from Ghosh et al. (2023), essentially the same for each of these theorems, is based on an auxiliary result whose corresponding versions are Theorem B.3, Theorem C.3, Theorem D.3 and Theorem E.3. The proof of this result is also elementary but long, and it is carried out through a series of lemmas in Appendix F and Appendix G. Out of these four, we only prove Theorem B.3 since the other three results are proven in the same way with obvious changes.
Appendix H contains some details about the numerical experiments.
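The two RMSProp versions mentioned above can be illustrated by the following minimal sketch. It assumes the standard RMSProp update (an exponential moving average of squared gradients in the denominator). The symbol names `h`, `rho`, `eps`, the flag `eps_inside`, and the toy quadratic loss are illustrative assumptions, not the notation of Definitions B.1 and C.1.

```python
import math


def rmsprop_step(theta, nu, grad, h, rho, eps, eps_inside):
    """One RMSProp step on lists of floats; returns (new_theta, new_nu).

    Assumed standard form: nu <- rho * nu + (1 - rho) * grad^2, then
    theta <- theta - h * grad / denom, where denom depends on eps placement.
    """
    new_theta, new_nu = [], []
    for t, v, g in zip(theta, nu, grad):
        v_next = rho * v + (1.0 - rho) * g * g
        if eps_inside:
            denom = math.sqrt(v_next + eps)   # eps inside the square root
        else:
            denom = math.sqrt(v_next) + eps   # eps outside the square root
        new_theta.append(t - h * g / denom)
        new_nu.append(v_next)
    return new_theta, new_nu


def quadratic_loss(theta):
    """Toy loss (hypothetical, for illustration only)."""
    return 0.5 * (theta[0] ** 2 + 10.0 * theta[1] ** 2)


def quadratic_grad(theta):
    """Gradient of the toy loss above."""
    return [theta[0], 10.0 * theta[1]]


def run(eps_inside, steps=200, h=0.01, rho=0.9, eps=1e-8):
    """Run full-batch RMSProp on the toy loss from a fixed starting point."""
    theta, nu = [1.0, 1.0], [0.0, 0.0]
    for _ in range(steps):
        grad = quadratic_grad(theta)
        theta, nu = rmsprop_step(theta, nu, grad, h, rho, eps, eps_inside)
    return theta
```

For small `eps` the two variants behave almost identically. The difference becomes visible only when `eps` is comparable to the typical gradient magnitude.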
A.1 Notation
We denote the loss of the th minibatch as a function of the network parameters by , and in the full-batch setting we omit the index and write . As usual, means the gradient of , and nabla with indices means partial derivatives, e.g., is a shortcut for .
The letter will always denote a finite time horizon of the ODEs, will always denote the training step size, and we will replace with when convenient, where is the step number. We will use the same notation for the iteration of the discrete algorithm , the piecewise ODE solution and some auxiliary terms for each of the four algorithms: see Definition B.1, Definition C.1, Definition D.1, Definition E.1. This way, we avoid cluttering the notation significantly. We are careful to reference the relevant definition in all theorem statements.
Appendix B RMSProp with ε Outside the Square Root
Definition B.1.
In this section, for some , , , let the sequence of -vectors be defined for by
| (12) |
Let be defined as a continuous solution to the piecewise ODE
| (13) |
for with the initial condition , where , and are -dimensional functions with components
Assumption B.2.
1. For some positive constants , , , we have
2. For some we have for all
where is defined in Definition B.1.
Theorem B.3 (RMSProp with ε outside: local error bound).
Suppose B.2 holds. Then for all ,
for a positive constant depending on .
The proof of Theorem B.3 is conceptually simple but very technical, and we defer it to Appendix G. Taking it as given for now and combining it with a simple induction argument yields the following global error bound.
Theorem B.4 (RMSProp with ε outside: global error bound).
Suppose B.2 holds, and
for defined in Definition B.1. Then there exist positive constants , , such that for all
where . The constants can be defined as
Proof.
We will show this by induction over , the same way an analogous bound is shown in Ghosh et al. (2023).
The base case is . Indeed, . Then the th component of is
By Theorem B.3, the absolute value of the right-hand side does not exceed , which means . Since , the base case is proven.
Now suppose that for all the claim
is proven. Then
where step 1 is by the triangle inequality, step 2 is by , and in step 3 we used for all .
B.1 RMSProp with ε Outside: Full-Batch
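The mechanism behind this kind of induction can be checked numerically with a generic sketch. It assumes a worst-case error recursion of the form e_{n+1} = (1 + L h) e_n + C h^p, where `L`, `C` and `p` are hypothetical placeholders rather than the constants of Theorem B.4. Using 1 + x ≤ exp(x), the accumulated error is bounded by (C / L) h^{p-1} (exp(L n h) − 1), so the global order is one lower than the local one.

```python
import math


def worst_case_errors(C, L, h, p, steps):
    """Iterate e_{n+1} = (1 + L*h) * e_n + C * h**p starting from e_0 = 0."""
    errors = [0.0]
    for _ in range(steps):
        errors.append((1.0 + L * h) * errors[-1] + C * h ** p)
    return errors


def closed_form_bound(C, L, h, p, n):
    """Bound obtained from 1 + x <= exp(x): (C / L) * h**(p-1) * (exp(L*n*h) - 1)."""
    return (C / L) * h ** (p - 1) * (math.exp(L * n * h) - 1.0)


# Example: a third-order one-step error accumulates to a second-order global error.
errors = worst_case_errors(C=1.0, L=2.0, h=0.01, p=3, steps=500)
bounds = [closed_form_bound(1.0, 2.0, 0.01, 3, n) for n in range(501)]
```

The exponential factor explains why such bounds are stated on a finite time horizon.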
In the full-batch setting , the terms in (13) simplify to
If is small and the iteration number is large, (13) simplifies to
Appendix C RMSProp with ε Inside the Square Root
Definition C.1.
In this section, for some , , , let the sequence of -vectors be defined for by
| (19) |
Let be defined as a continuous solution to the piecewise ODE
| (20) |
for with the initial condition , where , and are -dimensional functions with components
| (21) |
Assumption C.2.
For some positive constants , , , we have
Theorem C.3 (RMSProp with ε inside: local error bound).
Suppose C.2 holds. Then for all ,
for a positive constant depending on , where is defined in Definition C.1.
The argument is the same as for Theorem B.3.
Theorem C.4 (RMSProp with ε inside: global error bound).
Suppose C.2 holds. Then there exist positive constants , , such that for all
where ; and are defined in Definition C.1. The constants can be defined as
The argument is the same as for Theorem B.4.
C.1 RMSProp with ε Inside: Full-Batch
Appendix D Adam with ε Outside the Square Root
Definition D.1.
In this section, for some , , , let the sequence of -vectors be defined for by
or, rewriting,
| (22) |
Let be defined as a continuous solution to the piecewise ODE
| (23) |
for with the initial condition , where , , , , , are -dimensional functions with components
| (24) |
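As a concrete reference point for the discrete iteration in this section, here is a minimal sketch of one Adam step with the stability parameter outside the square root. It assumes the standard bias-corrected form from Kingma & Ba (2015). The names `beta` (momentum parameter), `rho` (squared-gradient averaging parameter), `h` and `eps` are assumptions and need not match Definition D.1 exactly.

```python
import math


def adam_step(theta, m, nu, grad, n, h, beta, rho, eps):
    """One bias-corrected Adam step; n is the 0-based step index.

    Returns (new_theta, new_m, new_nu). The stability parameter eps is
    added outside the square root.
    """
    new_theta, new_m, new_nu = [], [], []
    for t, mj, vj, g in zip(theta, m, nu, grad):
        m_next = beta * mj + (1.0 - beta) * g        # average of gradients
        v_next = rho * vj + (1.0 - rho) * g * g      # average of squared gradients
        m_hat = m_next / (1.0 - beta ** (n + 1))     # bias correction
        v_hat = v_next / (1.0 - rho ** (n + 1))
        new_theta.append(t - h * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_next)
        new_nu.append(v_next)
    return new_theta, new_m, new_nu


# With eps = 0, the very first step moves each coordinate by exactly h
# against the sign of its gradient.
first_theta, _, _ = adam_step([1.0, -2.0], [0.0, 0.0], [0.0, 0.0],
                              [3.0, -0.5], 0, 0.1, 0.9, 0.999, 0.0)
```

Moving `eps` inside the square root only changes the denominator to `math.sqrt(v_hat + eps)`.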
Assumption D.2.
1. For some positive constants , , , we have
2. For some we have for all
where is defined in Definition D.1.
Theorem D.3 (Adam with ε outside: local error bound).
Suppose D.2 holds. Then for all ,
for a positive constant depending on and .
The argument is the same as for Theorem B.3.
Theorem D.4 (Adam with ε outside: global error bound).
Suppose D.2 holds, and
for defined in Definition D.1. Then there exist positive constants , , such that for all
where . The constants can be defined as
Proof.
Analogously to Theorem B.4, we will prove this by induction over .
The base case is . Indeed, . Then the th component of is
By Theorem D.3, the absolute value of the right-hand side does not exceed , which means . Since , the base case is proven.
Now suppose that for all the claim
is proven. Then
where step 1 is by the triangle inequality, step 2 is by , and in step 3 we used for all .
Using , , we have
| (26) |
But since
we have
| (27) |
Similarly,
| (28) |
Appendix E Adam with ε Inside the Square Root
Definition E.1.
In this section, for some , , , let the sequence of -vectors be defined for by
| (31) |
Let be defined as a continuous solution to the piecewise ODE
| (32) |
for with the initial condition , where , , , , , are -dimensional functions with components
| (33) |
Assumption E.2.
For some positive constants , , , we have
Theorem E.3 (Adam with ε inside: local error bound).
Suppose E.2 holds. Then for all ,
for a positive constant depending on and .
The argument is the same as for Theorem B.3.
Theorem E.4 (Adam with ε inside: global error bound).
Suppose E.2 holds for defined in Definition E.1. Then there exist positive constants , , such that for all
where . The constants can be defined as
The argument is the same as for Theorem D.4.
Appendix F Bounding the Derivatives of the ODE Solution
Our first goal is to argue that the first derivative of is uniformly bounded in absolute value. To achieve this, we just need to bound all the terms on the right-hand side of the ODE (13).
Lemma F.1.
Lemma F.2.
Suppose B.2 holds. Then the first derivative of is uniformly over and bounded in absolute value by some positive constant, say .
Our next goal is to argue that the second derivative of is bounded in absolute value. For this, we need to bound the first derivatives of all the three additive terms on the right-hand side of (13).
Lemma F.3.
Suppose B.2 holds. Then for all , we have
| (36) |
| (37) |
| (38) |
| (39) |
| (40) |
| (41) |
| (42) |
| (43) |
with constants , , , , , , , , , , , defined as follows:
Proof of Lemma F.3.
We prove the inequalities one by one.
The bound (36) is straightforward:
The bound (38) follows from the assumptions immediately.
We will prove (39) by bounding the two additive terms on the right-hand side of the equality
| (44) |
It is easily shown that the first term in (44) is bounded in absolute value by :
For the proof of (39), it is left to show that the second term in (44) is bounded in absolute value by .
To bound , we can use
By the Cauchy-Schwarz inequality applied twice,
Next, for any and
| (45) |
This gives
We have obtained
| (46) |
We will prove (40) by bounding the four terms in the expression
where
To bound Term1, use , giving
To bound Term2, use , giving
To bound Term3, use , giving
Combining two bounds above, we have
We are ready to conclude
It is left to prove (43). Since
and, as we have already seen in the argument for (40),
we are ready to bound
The proof of Lemma F.3 is concluded. ∎
Lemma F.4.
Suppose B.2 holds. Then the second derivative of is uniformly over and bounded in absolute value by some positive constant, say .
Proof.
Finally, we need to argue that the third derivative of is bounded in absolute value. To achieve this, we need to bound the second derivatives of the terms on the right-hand side of (13).
Lemma F.5.
Suppose B.2 holds. Then for all ,
| (50) |
| (51) |
| (52) |
| (53) |
| (54) |
| (55) |
with constants , , , , , defined as follows:
Proof of Lemma F.5.
We prove the inequalities one by one.
The proof of (50) is straightforward:
To justify (54), put temporarily , and use
combined with
The proof of Lemma F.5 is concluded. ∎
Lemma F.6.
Suppose B.2 holds. Then the third derivative of is uniformly over and bounded in absolute value by some positive constant, say .
Proof.
From the definition of , it means that its derivatives up to order two are bounded. Similarly, the same is true for .
It follows from (52) and its proof that the derivatives up to order two of
are also bounded.
Appendix G Proof of Theorem B.3
Our next objective is proving and identifying the constant in the equality
We will make some preparations and achieve this objective in Lemma G.5. Then we will conclude the proof of Theorem B.3.
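The role of the derivative bounds from Appendix F can be seen from Taylor's theorem with the Lagrange remainder. The statement below is for a generic three times continuously differentiable scalar function $f$ (a placeholder symbol, not the paper's notation) and some point $\xi$ between $t$ and $t+h$:

```latex
\[
f(t+h) = f(t) + h f'(t) + \frac{h^2}{2} f''(t) + \frac{h^3}{6} f'''(\xi),
\qquad
\left| \frac{h^3}{6} f'''(\xi) \right| \le \frac{h^3}{6} \sup_{s \in [t,\, t+h]} \left| f'''(s) \right| .
\]
```

A uniform bound on the third derivative therefore controls the one-step remainder at order $h^3$, while bounds on the first and second derivatives control the lower-order terms.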
Lemma G.1.
Suppose B.2 holds. Then for all , , we have
| (56) |
Proof.
(56) follows from the mean value theorem applied times. ∎
Lemma G.2.
In the setting of Lemma G.1, for any we have
Proof.
By the Taylor expansion of on the segment at on the left
Combining this with (37) gives
| (57) |
Now applying the mean-value theorem times, we have by (46)
and in particular
Lemma G.3.
In the setting of Lemma G.1,
Lemma G.4.
Proof.
We continue by showing
| (60) |
To prove this, use
with
and bounding
Lemma G.5.
Suppose B.2 holds. Then
Proof.
We are finally ready to prove Theorem B.3.
Proof of Theorem B.3.
By (42) and (43), the first derivative of the function
is bounded in absolute value by a positive constant . By (13), this means
Combining this with
by Taylor expansion, we get
| (61) |
Using
with defined as
by (13), and calculating the derivative, it is easy to show
| (62) |
for a positive constant , where
It is left to combine this with Lemma G.5, giving the assertion of the theorem with
Appendix H Numerical Experiments
H.1 Models
We use small modifications of the ResNet-50 and ResNet-101 implementations in the torchvision library for training on CIFAR-10 and CIFAR-100. The first convolution layer conv1 has kernel, stride 1 and “same” padding. It is followed by batch normalization and ReLU. Max pooling is removed, and otherwise conv2_x to conv5_x are as described in He et al. (2016) (see Table 1 there), except that downsampling is performed by the middle convolution of each bottleneck block, as in version 1.5 (https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch).
After conv5 there is global average pooling and a 10-way or 100-way fully connected layer (for CIFAR-10 and CIFAR-100 respectively).
The MLP that we use for showing the closeness of trajectories in Figure 3 consists of two fully connected layers, each with 32 units and GeLU activation, followed by a fully-connected layer with 10 units.
In Figure 3, the curves called “first order” plot and the curves called “second order” plot , where is the Adam iteration defined in Definition 1.1 and
| (63) |
for and as defined in Section 3, with the same initial point .
H.2 Data Augmentation
We subtract the per-pixel mean and divide by the standard deviation, and we use the data augmentation scheme from Lee et al. (2015), following He et al. (2016), section 4.2. During each pass over the training dataset, each image is padded evenly with zeros so that it becomes , then a random crop is applied so that the picture becomes again, and a random (probability ) horizontal (left-to-right) flip is used.
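The pad, crop and flip pipeline can be sketched as follows for a single-channel image stored as a list of rows. The function name and the parameters `pad` and `flip_prob` are hypothetical, and their values are deliberately left to the caller rather than fixed here.

```python
import random


def pad_crop_flip(image, pad, flip_prob, rng):
    """Zero-pad evenly, take a random crop of the original size, then flip.

    image: list of rows (H x W) of numbers; rng: a random.Random instance.
    Returns a new H x W image.
    """
    height, width = len(image), len(image[0])
    padded_width = width + 2 * pad
    # Zero rows above and below, zero columns to the left and right.
    padded = [[0] * padded_width for _ in range(pad)]
    padded += [[0] * pad + list(row) + [0] * pad for row in image]
    padded += [[0] * padded_width for _ in range(pad)]
    # Random top-left corner such that the crop stays inside the padded image.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    crop = [row[left:left + width] for row in padded[top:top + height]]
    # Horizontal (left-to-right) flip with probability flip_prob.
    if rng.random() < flip_prob:
        crop = [row[::-1] for row in crop]
    return crop


# Toy usage on a 2 x 2 image with a fixed seed.
rng = random.Random(0)
augmented = pad_crop_flip([[1, 2], [3, 4]], pad=1, flip_prob=0.5, rng=rng)
```

For multi-channel images, the same crop offsets and the same flip decision would be applied to every channel.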
H.3 Experiment Details
In experiments whose results are reported in Figures 4 and 5 of the main paper, we train for a few thousand epochs and stop training when the train accuracy is near-perfect (Figure 11) and the testing accuracy does not significantly improve (Figure 12). Therefore, the maximal test accuracies are the final ones reached, and the maximal perturbed one-norms, after excluding the initial fall at the beginning of training, are at peaks of the “hills” on the norm curves (Figure 12).
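The perturbed one-norm tracked in these experiments can be computed as in the following minimal sketch. It assumes the definition ‖v‖_{1,ε} = Σ_i sqrt(v_i² + ε). This appendix does not restate the definition, so both the formula and the function name are assumptions.

```python
import math


def perturbed_one_norm(vector, eps):
    """Assumed definition: sum over components of sqrt(v_i**2 + eps)."""
    return sum(math.sqrt(v * v + eps) for v in vector)


# With eps = 0 this is the ordinary one-norm; a positive eps increases it.
plain = perturbed_one_norm([3.0, -4.0], 0.0)
perturbed = perturbed_one_norm([3.0, -4.0], 1e-6)
```

Under this assumed definition the quantity is smooth in every component, unlike the ordinary one-norm, which is not differentiable where a component is zero.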
Since the full dataset does not fit into GPU memory, we divide it into 100 “ghost batches” and accumulate the gradients before doing one optimization step. This means that we use ghost batch normalization (Hoffer et al., 2017) as opposed to full-dataset batch normalization, similarly to Cohen et al. (2021).
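The gradient accumulation described above can be illustrated with a toy sketch. It uses a one-parameter least-squares model (a hypothetical stand-in for the network) and assumes equal-size ghost batches, in which case the averaged per-batch gradients equal the full-dataset gradient.

```python
def mean_gradient(w, xs, ys):
    """Gradient w.r.t. scalar w of the mean of 0.5 * (w * x - y)**2."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_gradient(w, xs, ys, num_batches):
    """Average per-batch mean gradients over equal-size ghost batches."""
    size = len(xs) // num_batches  # assumes num_batches divides len(xs)
    total = 0.0
    for b in range(num_batches):
        lo, hi = b * size, (b + 1) * size
        total += mean_gradient(w, xs[lo:hi], ys[lo:hi]) / num_batches
    return total


# Toy data: 8 examples split into 4 ghost batches.
xs = [float(i) for i in range(1, 9)]
ys = [2.0 * x + 1.0 for x in xs]
full = mean_gradient(0.5, xs, ys)
accumulated = accumulated_gradient(0.5, xs, ys, num_batches=4)
```

This equality concerns gradients of a fixed function only. Batch normalization statistics are still computed per ghost batch, which is why the setup corresponds to ghost batch normalization.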
H.4 Additional Evidence
We provide evidence that the results in Figures 4 and 5 are robust to changes of architecture. In Figures 7 and 8, we show that the pictures are similar for a simple CNN created by the following code:

```python
from torch import nn


# The function name is assumed; the original excerpt starts inside the function body.
def build_cnn(num_classes):
    layers = [
        # First block
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # Second block
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # Third block
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # Flatten and dense layers
        nn.Flatten(),
        nn.Linear(in_features=128 * 4 * 4, out_features=512),
        nn.ReLU(),
        nn.Linear(in_features=512, out_features=num_classes),
    ]
    return nn.Sequential(*layers)
```

In Figures 9 and 10, we show that the same conclusions can be made for a Vision Transformer (Dosovitskiy et al., 2020; Beyer et al., 2022). In these experiments, we use the SimpleViT architecture from the vit-pytorch library with patches, 6 transformer blocks with 16 heads, embedding size 512 and MLP dimension of 1024 (following Andriushchenko et al. (2023)).
Appendix I Adam with ε Inside the Square Root: Informal Derivation
Our goal is to find such a trajectory that
Result I.1.
For we have
| (64) |
Derivation.
We take
for granted. Using this and the Taylor series, we can write
where in the last equality we just replaced with in the -term since it only affects higher-order terms. Now doing this again for step instead of step , we will have
where in the last equality we again replaced with since it only affects higher-order terms. Proceeding like this and adding the resulting equations, we have for , that
where we ignored the fact that is not bounded (we will get away with this because of exponential averaging). Hence, taking the square of this formal power series,
Summing up over , we have
which, using the expression for the inverse square root of a formal power series , gives us
Similarly,
We conclude
Result I.2.
For , the modified equation is (32).
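Two elementary identities are useful when following the derivation above. They are written here with generic coefficients $a_0 > 0$, $a_1$ and a generic averaging parameter $\rho \in (0, 1)$; these are placeholder symbols, not the paper's exact terms. The first is the first-order expansion of the inverse square root of a formal power series in $h$, and the second is the geometric sum that normalizes an exponential moving average:

```latex
\[
\bigl( a_0 + a_1 h + O(h^2) \bigr)^{-1/2}
  = a_0^{-1/2} - \frac{a_1}{2\, a_0^{3/2}}\, h + O(h^2),
\]
\[
(1 - \rho) \sum_{k=0}^{n} \rho^{\,n-k} = 1 - \rho^{\,n+1}.
\]
```

The first identity follows from $(1 + x)^{-1/2} = 1 - x/2 + O(x^2)$ applied to $x = (a_1 / a_0)\, h + O(h^2)$.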
