URL: https://arxiv.org/html/2309.00079v4

On the Implicit Bias of Adam

Matias D. Cattaneo Jason M. Klusowski Boris Shigida
Abstract

In previous literature, backward error analysis was used to find ordinary differential equations (ODEs) approximating the gradient descent trajectory. It was found that finite step sizes implicitly regularize solutions because terms appearing in the ODEs penalize the two-norm of the loss gradients. We prove that the existence of similar implicit regularization in RMSProp and Adam depends on their hyperparameters and the training stage, but with a different “norm” involved: the corresponding ODE terms either penalize the (perturbed) one-norm of the loss gradients or, conversely, impede its reduction (the latter case being typical). We also conduct numerical experiments and discuss how the proven facts can influence generalization.

theory, optimization, implicit bias, generalization, backward error analysis, modified equations, Adam, adaptive methods, gradient descent, sharpness

1 Introduction

Gradient descent (GD), $\theta^{(n+1)} = \theta^{(n)} - h\nabla E(\theta^{(n)})$, can be seen as a numerical method solving the ordinary differential equation (ODE) $\dot\theta = -\nabla E(\theta)$, where $E(\cdot)$ is the loss function and $\nabla E(\cdot)$ is its gradient. Starting at $\theta^{(0)}$, it creates a sequence of guesses $\theta^{(1)}, \theta^{(2)}, \ldots$, which lie close to the solution trajectory $\theta(t)$ governed by the aforementioned ODE. Since the step size $h$ is finite, one could search for a modified differential equation $\dot{\tilde\theta} = -\nabla\tilde E(\tilde\theta)$ such that $\theta^{(n)} - \tilde\theta(nh)$ is exactly zero, or at least closer to zero than $\theta^{(n)} - \theta(nh)$. In other words, all the guesses of the descent lie exactly on the new solution curve, or closer to it than to the original curve. This approach to analysing properties of a numerical method is sometimes called backward error analysis in the numerical integration literature (see Chapter IX in Ernst Hairer & Wanner (2006) and references therein).

Barrett & Dherin (2021) used this idea for full-batch GD and found that the modified loss function $\tilde E(\theta) = E(\theta) + \frac{h}{4}\|\nabla E(\theta)\|^2$ makes the trajectory of the solution to $\dot{\tilde\theta} = -\nabla\tilde E(\tilde\theta)$ approximate the sequence $\{\theta^{(n)}\}$ one order of $h$ better than the original ODE, where $\|\cdot\|$ is the Euclidean norm. In related work, Miyagawa (2022) obtained the correction term for full-batch GD up to any chosen order, also studying the global error (uniform in the iteration number) as opposed to the local (one-step) error.
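The "one order of $h$ better" statement can be made concrete on a one-dimensional quadratic loss $E(\theta) = \lambda\theta^2/2$, where every trajectory has a closed form. This toy choice is ours and is not an experiment from the paper. The minimal Python sketch below compares GD iterates with the original gradient flow and with the flow of the modified loss $E + \frac{h}{4}\|\nabla E\|^2 = (\lambda + h\lambda^2/2)\,\theta^2/2$.

```python
import math


def gd_iterates(lam, h, theta0, n):
    """Gradient descent on the toy loss E(theta) = lam * theta**2 / 2."""
    thetas = [theta0]
    for _ in range(n):
        thetas.append(thetas[-1] - h * lam * thetas[-1])
    return thetas


def linear_flow(rate, h, theta0, n):
    """Exact solution of theta' = -rate * theta, sampled at t = 0, h, ..., n * h."""
    return [theta0 * math.exp(-rate * k * h) for k in range(n + 1)]


lam, h, theta0, n = 1.0, 0.1, 1.0, 50
gd = gd_iterates(lam, h, theta0, n)
# Original ODE: theta' = -grad E(theta) = -lam * theta.
plain = linear_flow(lam, h, theta0, n)
# Modified loss E + (h/4) * |grad E|^2 = (lam + h * lam**2 / 2) * theta**2 / 2.
modified = linear_flow(lam + h * lam**2 / 2, h, theta0, n)

err_plain = max(abs(a - b) for a, b in zip(gd, plain))
err_modified = max(abs(a - b) for a, b in zip(gd, modified))
print(f"max |GD - original flow| = {err_plain:.1e}")
print(f"max |GD - modified flow| = {err_modified:.1e}")
```

On a fixed time horizon, halving $h$ should roughly halve `err_plain` (global error $O(h)$) but divide `err_modified` by about four (global error $O(h^2)$).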

The analysis was later extended to mini-batch GD in Smith et al. (2021). Assume that the training set is split into batches of size $B$ and there are $m$ batches per epoch (so the training set size is $mB$). The cost function is rewritten as $E(\theta) = \frac{1}{m}\sum_{k=0}^{m-1}\hat E_k(\theta)$ with mini-batch costs denoted $\hat E_k(\theta)$. That work shows that after one epoch, the mean iterate of the algorithm, averaged over all possible shuffles of the batch indices, is close to the solution to $\dot\theta = -\nabla\tilde E_{\mathrm{SGD}}(\theta)$, where the modified loss is given by $\tilde E_{\mathrm{SGD}}(\theta) = E(\theta) + \frac{h}{4m}\sum_{k=0}^{m-1}\|\nabla\hat E_k(\theta)\|^2$.
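Once per-batch gradients are available, the modified loss of Smith et al. (2021) is easy to evaluate. Purely for illustration, the sketch below assumes least-squares mini-batch costs $\hat E_k(\theta) = \|X_k\theta - y_k\|^2/(2B)$. The function name `smith_modified_loss` is hypothetical.

```python
import numpy as np


def smith_modified_loss(theta, X_batches, y_batches, h):
    """Return (E, E_SGD) for hypothetical least-squares mini-batch costs
    E_k(theta) = ||X_k theta - y_k||^2 / (2B), where E is the mean of the E_k and
    E_SGD = E + h / (4m) * sum_k ||grad E_k(theta)||^2."""
    m = len(X_batches)
    losses, sq_grad_norms = [], []
    for X_k, y_k in zip(X_batches, y_batches):
        B = X_k.shape[0]
        r = X_k @ theta - y_k          # residuals of this mini-batch
        losses.append(r @ r / (2 * B))
        g = X_k.T @ r / B              # gradient of E_k at theta
        sq_grad_norms.append(g @ g)
    E = float(np.mean(losses))
    return E, E + h / (4 * m) * float(np.sum(sq_grad_norms))


rng = np.random.default_rng(0)
X, y = rng.normal(size=(12, 3)), rng.normal(size=12)
theta = np.zeros(3)
# m = 4 mini-batches of size B = 3.
E, E_sgd = smith_modified_loss(theta, np.split(X, 4), np.split(y, 4), h=0.1)
print(E, E_sgd)
```

When all mini-batches coincide, the correction reduces to the full-batch term $\frac{h}{4}\|\nabla E\|^2$ of Barrett & Dherin (2021).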

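Modified equations have also been derived for GD with heavy-ball momentum $\theta^{(n+1)} = \theta^{(n)} - h\nabla E(\theta^{(n)}) + \beta(\theta^{(n)} - \theta^{(n-1)})$, where $\beta$ is the momentum parameter. In the full-batch setting, it turns out that for $n$ large enough the iterate $\theta^{(n)}$ is close to $\theta(nh)$, where $\theta(t)$ is the continuous trajectory solving

$$\dot\theta(t) = -\frac{1}{1-\beta}\nabla E(\theta(t)) - \frac{h(1+\beta)}{4(1-\beta)^3}\nabla\|\nabla E(\theta(t))\|^2. \tag{1}$$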

Versions of this general result were proven in Farazmand (2020), Kovachki & Stuart (2021), Ghosh et al. (2023) under different assumptions. The focus of the latter work is the closest to ours since they interpret the correction term as implicit regularization. Their main theorem also provides the analysis for the general mini-batch case.
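A quick check of (1) on a toy example of our own (not taken from the cited works) uses $E(\theta) = \lambda\theta^2/2$ with $a = h\lambda$. Heavy-ball iterates then decay like $z^n$, where $z$ is the dominant root of $z^2 - (1+\beta-a)z + \beta = 0$. The other root contributes only a transient, which is why (1) is stated for $n$ large enough. Over one step of length $h$, the flow $\dot\theta = -\nabla E/(1-\beta)$ decays by the exponent $a/(1-\beta)$, and (1) adds $a^2(1+\beta)/(2(1-\beta)^3)$ to it.

```python
import math


def heavy_ball_rate(a, beta):
    """Per-step asymptotic decay exponent -log(z) of heavy-ball momentum on
    E(theta) = lam * theta**2 / 2 with a = h * lam, where z is the dominant root
    of z**2 - (1 + beta - a) z + beta = 0 (real when a <= (1 - sqrt(beta))**2)."""
    b = 1 + beta - a
    disc = b * b - 4 * beta
    if disc < 0:
        raise ValueError("complex roots: choose a <= (1 - sqrt(beta))**2")
    return -math.log((b + math.sqrt(disc)) / 2)


def flow_rates(a, beta):
    """Per-step decay exponents of theta' = -grad E / (1 - beta) and of the flow (1)."""
    first_order = a / (1 - beta)
    corrected = first_order + a**2 * (1 + beta) / (2 * (1 - beta) ** 3)
    return first_order, corrected


a, beta = 1e-3, 0.9
exact = heavy_ball_rate(a, beta)
first_order, corrected = flow_rates(a, beta)
print(f"heavy ball: {exact:.6f}, first order: {first_order:.6f}, with (1): {corrected:.6f}")
```

The correction term in (1) closes most of the gap between the first-order rate and the true heavy-ball rate. The remainder is of higher order in $a/(1-\beta)^2$.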

In another recent work, Zhao et al. (2022) introduce a regularization term to the loss function as a way to ensure finding flatter minima, improving generalization. The only difference between their term and the first-order correction coming from backward error analysis (up to a coefficient) is that the norm is not squared and regularization is applied on a per-batch basis.

The application of backward error analysis for approximating the discrete dynamics of adaptive algorithms such as RMSProp (Tieleman et al., 2012) and Adam (Kingma & Ba, 2015) is currently missing in the literature. Barrett & Dherin (2021) note that “it would be interesting to use backward error analysis to calculate the modified loss and implicit regularization for other widely used optimizers such as momentum, Adam and RMSprop”. Smith et al. (2021) reiterate that they “anticipate that backward error analysis could also be used to clarify the role of finite learning rates in adaptive optimizers like Adam”. Ghosh et al. (2023) agree that “RMSProp … and Adam …, albeit being powerful alternatives to SGD with faster convergence rates, are far from well-understood in the aspect of implicit regularization”. In a similar context, in Appendix G to Miyagawa (2022) it is mentioned that “its [Adam’s] counter term and discretization error are open questions”.

This work fills the gap by conducting backward error analysis for (mini-batch, and full-batch as a special case) Adam and RMSProp. Our main contributions are listed below.

  • In Theorem 3.1, we provide a global second-order in $h$ continuous piecewise ODE approximation to Adam in the general mini-batch setting. (A similar result for RMSProp is deferred to Appendix C.) For the full-batch special case, prior work (Ma et al., 2022) showed that the continuous-time limit of both these algorithms is a component-wise signGD flow (perturbed by the numerical stability parameter $\varepsilon$). We make this more precise by finding a correction term on the right-hand side that is linear in $h$.

  • We analyze the full-batch case in the context of regularization (see the summary in Section 2). In contrast to the case of GD, where the squared two-norm of the loss gradient is implicitly penalized, Adam typically anti-penalizes the perturbed one-norm of the loss gradient (i. e., penalizes the negative norm), as specified in (5). Thus, the implicit bias of Adam that we identify serves as anti-regularization, except in the unusual case $\beta \ge \rho$, for large $\varepsilon$, or very late in training.

  • We provide numerical evidence consistent with our theoretical results by training various vision models on CIFAR-10 using full-batch Adam. In particular, we observe that the stronger the implicit anti-regularization effect predicted by our theory, the worse the generalization. This pattern holds across different architectures: ResNets, simple convolutional neural networks (CNNs) and Vision Transformers. Thus, we propose a novel possible explanation for the often-reported poor generalization of adaptive gradient algorithms. The code used for training the models is available at https://github.com/borshigida/implicit-bias-of-adam.

1.1 Related Work

Backward error analysis of first-order methods. We outlined the history of finding ODEs approximating different algorithms above in the introduction. Recently, there have been other applications of backward error analysis related to machine learning. Kunin et al. (2020) show that the approximating continuous-time trajectories satisfy conservation laws that are broken in discrete time. França et al. (2021) use backward error analysis while studying how to discretize continuous-time dynamical systems preserving stability and convergence rates. Rosca et al. (2021) find continuous-time approximations of discrete two-player differential games.

Approximating gradient methods by differential equation trajectories. Barakat & Bianchi (2021) derived a first-order continuous ODE approximation to Adam under the assumption that its hyperparameters $\beta, \rho$ (see Definition 1.1) tend to 1 at a certain rate as $h \to 0$. On the other hand, if $\beta, \rho$ are kept fixed, Ma et al. (2022) prove that the trajectories of Adam and RMSProp are close to signGD dynamics, and investigate different training regimes of these algorithms empirically. Li et al. (2017) approximate SGD by stochastic differential equations and devise novel adaptive parameter adjustment policies. Malladi et al. (2022) derive stochastic differential equations that are order-1 weak approximations of RMSProp and Adam. We go in a different direction. Instead of refining the previously obtained continuous ODE approximations by taking gradient noise into account, we take a deterministic approach but go one order of $h$ further. In particular, we keep $\beta, \rho$ fixed (thus generalizing the analysis for SGD with momentum), whereas Malladi et al. (2022) let $\beta, \rho \to 1$ as $h \to 0$.

Connection with signGD. Bernstein et al. (2018) discuss the connection of adaptive gradient methods with sign(S)GD extensively. Balles et al. (2020) study a version of signGD with an update proportional to $\|\nabla E(\theta)\|_1 \operatorname{sign}\nabla E(\theta)$ as a special case of steepest descent, and discuss when sign-based methods are preferable to GD.

Implicit bias of first-order methods. Soudry et al. (2018) prove that GD trained to classify linearly separable data with logistic loss converges to the direction of the max-margin vector (the solution to the hard margin SVM). This result has been extended to different loss functions in Nacson et al. (2019b), to SGD in Nacson et al. (2019c), AdaGrad in Qian & Qian (2019), (S)GD with momentum, deterministic Adam and stochastic RMSProp in Wang et al. (2022), more generic optimization methods in Gunasekar et al. (2018a), to the nonseparable case in Ji & Telgarsky (2018b), Ji & Telgarsky (2019). This line of research has been generalized to studying implicit biases of linear networks (Ji & Telgarsky, 2018a; Gunasekar et al., 2018b), homogeneous neural networks (Ji & Telgarsky, 2020; Nacson et al., 2019a; Lyu & Li, 2019). Woodworth et al. (2020) study the gradient flow of a diagonal linear network with squared loss and show that large initializations lead to minimum two-norm solutions while small initializations lead to minimum one-norm solutions. Even et al. (2023) extend this work to the case of non-zero step sizes and mini-batch training. Wang et al. (2021) prove that Adam and RMSProp maximize the margin of homogeneous neural networks. Our perspective on the implicit bias is different since we are considering a generic loss function without any assumptions on the network architecture. Beneventano (2023) proves that in expectation over batch sampling the trajectory of SGD without replacement differs from that of SGD with replacement by an additional step on a regularizer. As opposed to the work on backward error analysis for SGD discussed above, they do not assume the largest eigenvalue of the hessian to be bounded.

Generalization of adaptive methods. Cohen et al. (2022) investigate the edge-of-stability regime of adaptive gradient algorithms and the effect of sharpness (the largest eigenvalue of the hessian) on generalization. Granziol (2020); Chen et al. (2021) observe that adaptive methods find sharper minima than SGD and Zhou et al. (2020); Xie et al. (2022) argue theoretically that it is the case. Jiang et al. (2022) introduce a statistic that measures the uniformity of the hessian diagonal and argue that adaptive gradient algorithms are biased towards making this statistic smaller. Keskar & Socher (2017) propose to improve generalization of adaptive methods by switching to SGD in the middle of training.

1.2 Notation

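We denote the loss of the $k$th minibatch as a function of the network parameters $\theta \in \mathbb{R}^p$ by $E_k(\theta)$. In the full-batch setting we omit the index and write $E(\theta)$. $\nabla E$ means the gradient of $E$, and $\nabla$ with indices denotes partial derivatives, e. g. $\nabla_{ijs}E$ is a shortcut for $\frac{\partial^3 E}{\partial\theta_i\,\partial\theta_j\,\partial\theta_s}$. The norm notation $\|\cdot\|$ without indices is the two-norm of a vector, $\|\cdot\|_1$ is the one-norm, and $\|\cdot\|_{1,\varepsilon}$ is the perturbed one-norm defined as $\|v\|_{1,\varepsilon} := \sum_{i=1}^p\sqrt{v_i^2 + \varepsilon}$. (Of course, if $\varepsilon > 0$ the perturbed one-norm is not really a norm, but taking $\varepsilon = 0$ makes it the one-norm.) For a real number $x$, the floor $\lfloor x\rfloor$ is the largest integer not exceeding $x$.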
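The perturbed one-norm is used throughout. A one-line NumPy version makes explicit the two facts used later: it reduces to the one-norm at $\varepsilon = 0$, and it is bounded below by $p\sqrt{\varepsilon}$. The function name is ours.

```python
import numpy as np


def perturbed_one_norm(v, eps):
    """||v||_{1,eps} = sum_i sqrt(v_i**2 + eps); equals the one-norm when eps = 0."""
    v = np.asarray(v, dtype=float)
    return float(np.sum(np.sqrt(v**2 + eps)))


v = np.array([3.0, -4.0, 0.0])
print(perturbed_one_norm(v, 0.0))             # 7.0, the one-norm
print(perturbed_one_norm(v, 1e-8))            # slightly above 7
print(perturbed_one_norm(np.zeros(3), 0.01))  # the minimum p * sqrt(eps), here 3 * 0.1
```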

To provide the names and notations for hyperparameters, we define the algorithm below.

Definition 1.1.

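The Adam algorithm (Kingma & Ba, 2015) is an optimization algorithm with numerical stability hyperparameter $\varepsilon > 0$, squared gradient momentum hyperparameter $\rho \in (0, 1)$, gradient momentum hyperparameter $\beta \in [0, 1)$, initialization $\theta^{(0)} \in \mathbb{R}^p$, $\nu^{(0)} = 0 \in \mathbb{R}^p$, $m^{(0)} = 0 \in \mathbb{R}^p$, and the following update rule: for each $n \ge 0$ and $j \in \{1, \ldots, p\}$,

$$\nu_j^{(n+1)} = \rho\,\nu_j^{(n)} + (1-\rho)\left(\nabla_j E_n(\theta^{(n)})\right)^2,$$
$$m_j^{(n+1)} = \beta\,m_j^{(n)} + (1-\beta)\,\nabla_j E_n(\theta^{(n)}),$$
$$\theta_j^{(n+1)} = \theta_j^{(n)} - h\,\frac{m_j^{(n+1)}/(1-\beta^{n+1})}{\sqrt{\nu_j^{(n+1)}/(1-\rho^{n+1}) + \varepsilon}}.$$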
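Below is a minimal NumPy sketch of this update rule, with $\varepsilon$ inside the square root and bias-corrected moments. It assumes a user-supplied gradient oracle `grad_fn(n, theta)` (a hypothetical name) that returns the mini-batch gradient $\nabla E_n(\theta)$ used at step $n$.

```python
import numpy as np


def adam(grad_fn, theta0, h, beta, rho, eps, n_steps):
    """Adam as in Definition 1.1: eps inside the square root, bias-corrected moments.

    grad_fn(n, theta) is a hypothetical oracle returning grad E_n(theta), the
    mini-batch gradient used at step n (full-batch code can ignore n).
    Returns the iterates theta^(0), ..., theta^(n_steps).
    """
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)   # gradient momentum, m^(0) = 0
    nu = np.zeros_like(theta)  # squared-gradient momentum, nu^(0) = 0
    iterates = [theta.copy()]
    for n in range(n_steps):
        g = grad_fn(n, theta)
        nu = rho * nu + (1 - rho) * g**2
        m = beta * m + (1 - beta) * g
        m_hat = m / (1 - beta ** (n + 1))
        nu_hat = nu / (1 - rho ** (n + 1))
        theta = theta - h * m_hat / np.sqrt(nu_hat + eps)
        iterates.append(theta.copy())
    return iterates


# Full-batch example on E(theta) = ||theta||^2, whose gradient is 2 * theta.
its = adam(lambda n, th: 2 * th, [1.0, -2.0], h=0.1, beta=0.9, rho=0.999, eps=1e-8, n_steps=3)
print(its[1])  # the first step is close to -h * sign(grad), i.e. about [0.9, -1.9]
```

Unlike this sketch, PyTorch's `torch.optim.Adam` adds `eps` outside the square root. Remark 1.2 reports that this placement does not change the qualitative behavior.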

Remark 1.2.

Note that the numerical stability hyperparameter $\varepsilon > 0$, which is introduced in these algorithms to avoid division by zero, is inside the square root in our definition. This way we avoid division by zero in the derivative too: the first derivative of $x \mapsto 1/\sqrt{x + \varepsilon}$ is bounded for $x \ge 0$. This is useful for our analysis. Theorems B.4 and D.4 also tackle the original versions of RMSProp and Adam, though with an additional assumption which requires that no component of the gradient can come very close to zero in the region of interest. This holds only for the initial period of learning (whereas Theorem 3.1 tackles the whole period). Practitioners do not seem to distinguish between the versions with $\varepsilon$ inside and outside the square root: tutorials with both versions abound on machine learning related websites. Moreover, the popular TensorFlow and Optax variants of RMSProp have $\varepsilon$ inside the square root. Empirically, we also observed that moving $\varepsilon$ inside or outside the square root does not change the behavior of Adam or RMSProp qualitatively.
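The boundedness claim is a one-line computation. The worked derivative below is our own illustration, with $x \ge 0$ standing for the (bias-corrected) squared-gradient average. It contrasts the two placements of $\varepsilon$.

```latex
\begin{align*}
\frac{d}{dx}\frac{1}{\sqrt{x+\varepsilon}} &= -\frac{1}{2(x+\varepsilon)^{3/2}},
  & \sup_{x \ge 0}\left|\frac{d}{dx}\frac{1}{\sqrt{x+\varepsilon}}\right| &= \frac{1}{2\varepsilon^{3/2}} < \infty, \\
\frac{d}{dx}\frac{1}{\sqrt{x}+\varepsilon} &= -\frac{1}{2\sqrt{x}\,(\sqrt{x}+\varepsilon)^{2}},
  & \lim_{x \to 0^{+}}\frac{d}{dx}\frac{1}{\sqrt{x}+\varepsilon} &= -\infty.
\end{align*}
```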

2 Implicit Bias of Full-Batch Adam: an Informal Summary

We are ready to describe our theoretical result (Theorem 3.1 below) in the full-batch special case. Assume $E$ is the loss, whose partial derivatives up to the fourth order are bounded. Let $\{\theta^{(n)}\}$ be iterations of Adam as defined in Definition 1.1. We find an ODE whose solution trajectory $\tilde\theta(t)$ is $O(h^2)$-close to $\{\theta^{(n)}\}$. This means that for any time horizon $T > 0$ there is a constant $C$ such that for any step size $h \in (0, T)$ we have $\|\tilde\theta(nh) - \theta^{(n)}\| \le Ch^2$ (for $n$ between $0$ and $\lfloor T/h\rfloor$). The ODE is written as follows (up to terms that rapidly go to zero as $n$ grows): for the component number $j \in \{1, \ldots, p\}$,

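$$\dot{\tilde\theta}_j(t) = -\frac{1}{\sqrt{|\nabla_j E(\tilde\theta(t))|^2 + \varepsilon}}\Big(\nabla_j E(\tilde\theta(t)) + \mathrm{bias}_j(\tilde\theta(t))\Big) \tag{2}$$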

with initial conditions $\tilde\theta_j(0) = \theta_j^{(0)}$ for all $j$, where the correction term is

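$$\mathrm{bias}_j(\theta) := \frac{h}{2}\left\{\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho} + \frac{1+\rho}{1-\rho}\cdot\frac{\varepsilon}{|\nabla_j E(\theta)|^2 + \varepsilon}\right\}\nabla_j\|\nabla E(\theta)\|_{1,\varepsilon}. \tag{3}$$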

Depending on hyperparameters and the training stage, the correction term can take two extreme forms listed below. The reality is in between, but typically much closer to the first case.

  • If $\varepsilon$ is small compared to all squared components of $\nabla E(\theta)$, i. e. $\varepsilon \ll \min_j |\nabla_j E(\theta)|^2$, which is usually the case during most of the training, then we can write

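    $$\mathrm{bias}_j(\theta) \approx \frac{h}{2}\left(\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}\right)\nabla_j\|\nabla E(\theta)\|_{1,\varepsilon}. \tag{4}$$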

    For small $\varepsilon$, the perturbed one-norm is indistinguishable from the usual one-norm. For $\beta > \rho$ it is penalized (in much the same way as the squared two-norm is implicitly penalized in the case of GD). In the typical case $\rho > \beta$, however, this term actually hinders its decrease, so the bias acts as anti-regularization. The ODE in (2) approximately becomes

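    $$\dot{\tilde\theta}_j(t) \approx -\frac{1}{\sqrt{|\nabla_j E(\tilde\theta(t))|^2 + \varepsilon}}\,\nabla_j\left(E + \frac{h}{2}\left(\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}\right)\|\nabla E\|_{1,\varepsilon}\right)(\tilde\theta(t)). \tag{5}$$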
  • Suppose $\varepsilon$ is large compared to all squared gradient components, i. e. $\varepsilon \gg \max_j |\nabla_j E(\theta)|^2$. This may happen during the late learning stage, or if non-standard hyperparameter values are chosen. Then the fraction in (3) with $\varepsilon$ in the numerator approaches one, the dependence on $\rho$ cancels out, and

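    $$\|\nabla E(\theta)\|_{1,\varepsilon} = \sum_{i=1}^p\sqrt{\varepsilon + |\nabla_i E(\theta)|^2} = \sqrt{\varepsilon}\sum_{i=1}^p\sqrt{1 + \frac{|\nabla_i E(\theta)|^2}{\varepsilon}} \approx p\sqrt{\varepsilon} + \frac{1}{2\sqrt{\varepsilon}}\|\nabla E(\theta)\|^2. \tag{6}$$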

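    In other words, $\|\nabla E(\theta)\|_{1,\varepsilon}$ becomes $\|\nabla E(\theta)\|^2/(2\sqrt{\varepsilon})$ up to an additive constant, giving

    $$\mathrm{bias}_j(\theta) \approx \frac{h(1+\beta)}{4\sqrt{\varepsilon}(1-\beta)}\nabla_j\|\nabla E(\theta)\|^2.$$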

    The form of the ODE in this case is

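    $$\dot{\tilde\theta}(t) \approx -\frac{1}{\sqrt{\varepsilon}}\nabla\left(E + \frac{h(1+\beta)}{4\sqrt{\varepsilon}(1-\beta)}\|\nabla E\|^2\right)(\tilde\theta(t)). \tag{7}$$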
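The NumPy sketch below evaluates the right-hand side of (2) and the correction term (3) from a gradient and a Hessian. It then checks the two limiting forms, (4) and the large-$\varepsilon$ form derived from (6), on a toy quadratic loss of our choosing. The names `adam_bias` and `ode_rhs` are hypothetical. Forming the full Hessian is only sensible for tiny $p$.

```python
import numpy as np


def adam_bias(grad, hess, h, beta, rho, eps):
    """Correction term bias_j(theta) of (3), full-batch case.

    grad: gradient of E at theta, shape (p,); hess: Hessian of E at theta, shape (p, p).
    Uses the chain rule nabla_j ||nabla E||_{1,eps} = sum_i H_ji g_i / sqrt(g_i**2 + eps).
    """
    A = (1 + beta) / (1 - beta)
    C = (1 + rho) / (1 - rho)
    grad_of_pnorm = hess @ (grad / np.sqrt(grad**2 + eps))
    coef = A - C + C * eps / (grad**2 + eps)  # curly bracket in (3), per component
    return 0.5 * h * coef * grad_of_pnorm


def ode_rhs(grad, hess, h, beta, rho, eps):
    """Right-hand side of (2) at theta: -(grad_j + bias_j) / sqrt(grad_j**2 + eps)."""
    return -(grad + adam_bias(grad, hess, h, beta, rho, eps)) / np.sqrt(grad**2 + eps)


# Toy quadratic E(theta) = theta^T Q theta / 2, so grad = Q theta and the Hessian is Q.
Q = np.array([[2.0, 0.5], [0.5, 1.0]])
theta = np.array([1.0, -1.0])
g, h, beta, rho = Q @ theta, 0.1, 0.9, 0.999

# Small eps: (3) is close to (4), where grad ||grad E||_1 = Q sign(grad E).
small = adam_bias(g, Q, h, beta, rho, eps=1e-12)
approx_4 = 0.5 * h * ((1 + beta) / (1 - beta) - (1 + rho) / (1 - rho)) * (Q @ np.sign(g))
# Large eps: (3) is close to h(1+beta) / (4 sqrt(eps) (1-beta)) * grad ||grad E||^2,
# and grad ||grad E||^2 = 2 Q g here.
big_eps = 1e6
large = adam_bias(g, Q, h, beta, rho, eps=big_eps)
approx_large = h * (1 + beta) / (4 * np.sqrt(big_eps) * (1 - beta)) * 2 * (Q @ g)
print(small, approx_4)
print(large, approx_large)
```

With $\beta = 0.9$ and $\rho = 0.999$, the small-$\varepsilon$ bias points against the gradient of the one-norm, which is the anti-penalization described above.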

These two extreme cases are summarized in Table 1. In Figure 1, we use the one-dimensional ($p = 1$) case to illustrate what kind of term is being implicitly penalized.

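Table 1: Implicit bias of Adam: special cases. “Small” and “large” refer to the size of $\varepsilon$ relative to the squared gradient components (Adam in the latter case is close to GD with momentum).

                     $\varepsilon$ “small”                          $\varepsilon$ “large”
$\beta \ge \rho$     $\|\nabla E\|_{1,\varepsilon}$-penalized       $\|\nabla E\|^2$-penalized
$\rho > \beta$       $-\|\nabla E\|_{1,\varepsilon}$-penalized      $\|\nabla E\|^2$-penalized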

Usually, $\varepsilon$ is chosen to be small, and during most of the training Adam is much better described by the first extreme case. It is clear from (5) that, if $\rho > \beta$, the correction term provides the opposite of regularization, in contrast to (1). The larger $\rho$ is compared to $\beta$, the stronger the anti-regularization effect.
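The size and sign of the effect in the small-$\varepsilon$ regime are governed by the single coefficient $\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}$ in (5). The short Python sketch below tabulates it for a few illustrative hyperparameter pairs. These include $\beta = 0.9$, $\rho = 0.999$, which in this notation corresponds to the default `betas=(0.9, 0.999)` of PyTorch's `torch.optim.Adam`.

```python
def small_eps_coefficient(beta, rho):
    """(1+beta)/(1-beta) - (1+rho)/(1-rho): the factor multiplying (h/2)||grad E||_{1,eps} in (5).
    A negative value means the perturbed one-norm is anti-penalized."""
    return (1 + beta) / (1 - beta) - (1 + rho) / (1 - rho)


for beta, rho in [(0.9, 0.999), (0.9, 0.99), (0.95, 0.95), (0.99, 0.9)]:
    c = small_eps_coefficient(beta, rho)
    effect = "anti-penalized" if c < 0 else "penalized" if c > 0 else "no leading-order effect"
    print(f"beta={beta}, rho={rho}: coefficient {c:+.1f}, one-norm {effect}")
```

For the common default the coefficient is about $19 - 1999 = -1980$. In the small-$\varepsilon$ regime, the anti-regularization is therefore far from a marginal effect.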

This finding may partially explain why adaptive gradient methods have been reported to generalize worse than non-adaptive ones (Chen et al., 2018; Wilson et al., 2017). It offers a previously unknown perspective on why they are biased towards “higher-curvature” regions and find “sharper” minima. Indeed, standard (non-adaptive) $\ell_\infty$-sharpness at $\theta$ can be defined by $\max_{\|\delta\|_\infty \le r}\{E(\theta + \delta) - E(\theta)\}$ for some radius $r > 0$. This and similar definitions have often been considered in the literature; see, e. g., Andriushchenko et al. (2023) and Foret et al. (2021). Replacing the difference of the losses under the maximum with its first-order approximation (Foret et al., 2021; Ghosh et al., 2023) gives

$$\max_{\|\delta\|_\infty \le r}\{E(\theta+\delta) - E(\theta)\} \approx \max_{\|\delta\|_\infty \le r}\langle\nabla E(\theta), \delta\rangle = r\|\nabla E(\theta)\|_1.$$

Since $\|\cdot\|_{1,\varepsilon}$ is close to $\|\cdot\|_1$ for small $\varepsilon$, Adam typically anti-penalizes this approximation of $\ell_\infty$-sharpness. Although the connection between sharpness and generalization is not clear-cut (Andriushchenko et al., 2023), our empirical results (Section 5) are consistent with this theory.
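The last equality above holds because the one-norm is the dual of the $\ell_\infty$-norm. A brute-force check over the corners of the box $\|\delta\|_\infty \le r$ confirms it for a small example. The sketch and its function name are ours.

```python
import itertools

import numpy as np


def linearized_linf_sharpness(grad, r):
    """max of <grad, delta> over the box ||delta||_inf <= r, by brute force.

    A linear function attains its maximum over a box at a corner, so checking the
    2**p corners is exact (and feasible only for small p)."""
    corners = itertools.product([-r, r], repeat=len(grad))
    return max(float(np.dot(grad, np.array(c))) for c in corners)


g = np.array([0.3, -1.2, 0.5])
r = 0.05
print(linearized_linf_sharpness(g, r), r * np.abs(g).sum())  # both 0.1 up to rounding
```

The maximizing corner is $\delta = r\,\operatorname{sign}\nabla E(\theta)$, the same direction a signGD step takes.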

This overview also applies to RMSProp by setting $\beta = 0$; see Theorem C.4 for the formal result.

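Figure 1: To illustrate what term is being implicitly penalized in the simple case $p = 1$, we plot the graphs of $v \mapsto G(v)$, where $G(v) := \frac{h}{2}\left\{\left(\frac{1+\beta}{1-\beta} - \frac{1+\rho}{1-\rho}\right)\sqrt{v^2+\varepsilon} - \frac{1+\rho}{1-\rho}\cdot\frac{\varepsilon}{\sqrt{v^2+\varepsilon}}\right\}$. In this case, the correction term in (3) is itself the gradient of the function $\theta \mapsto G(E'(\theta))$, where $v = E'(\theta)$ is the derivative (= gradient) of the loss: specifically, $\mathrm{bias}(\theta) = \frac{d}{d\theta}G(E'(\theta))$. Hence, Adam’s iteration penalizes $G(E'(\theta))$. If $\varepsilon$ is small and $\rho > \beta$, the negative one-norm of the gradient is penalized (leftmost picture, highest values of $\rho$); in other words, the one-norm is anti-penalized.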

Example 2.1 (Backward Error Analysis for GD with Heavy-ball Momentum).

Assume $\varepsilon$ is large compared to all squared gradient components during the whole training process, so that the form of the ODE is approximated by (7). After a certain number of iterations, Adam with a large $\varepsilon$ approximates SGD with heavy-ball momentum with step size $h(1-\beta)/\sqrt{\varepsilon}$. A linear step size change (and the corresponding time change) therefore gives exactly the equations in Theorem 4.1 of Ghosh et al. (2023). Taking $\beta = 0$ (no momentum), we get the implicit regularization of GD from Barrett & Dherin (2021).
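A minimal check of the heavy-ball correspondence runs under our own simplifying assumptions. Bias correction is dropped (its factors tend to 1 after enough iterations), $\varepsilon$ is huge compared to all squared gradients, and the loss is a toy quadratic. The Adam iterates then coincide with heavy-ball iterates that use step size $h(1-\beta)/\sqrt{\varepsilon}$ and $\theta^{(-1)} = \theta^{(0)}$, up to the tiny factor $\sqrt{\varepsilon}/\sqrt{\nu + \varepsilon}$. The function names are hypothetical.

```python
import numpy as np


def adam_without_bias_correction(grad, theta0, h, beta, rho, eps, n_steps):
    """Adam update of Definition 1.1 with the bias-correction factors dropped."""
    theta = np.array(theta0, dtype=float)
    m, nu = np.zeros_like(theta), np.zeros_like(theta)
    out = [theta.copy()]
    for _ in range(n_steps):
        g = grad(theta)
        nu = rho * nu + (1 - rho) * g**2
        m = beta * m + (1 - beta) * g
        theta = theta - h * m / np.sqrt(nu + eps)
        out.append(theta.copy())
    return np.array(out)


def heavy_ball(grad, theta0, step, beta, n_steps):
    """theta^(n+1) = theta^(n) - step * grad(theta^(n)) + beta * (theta^(n) - theta^(n-1))."""
    theta = np.array(theta0, dtype=float)
    prev = theta.copy()  # theta^(-1) = theta^(0)
    out = [theta.copy()]
    for _ in range(n_steps):
        theta, prev = theta - step * grad(theta) + beta * (theta - prev), theta
        out.append(theta.copy())
    return np.array(out)


Q = np.diag([1.0, 3.0])  # toy quadratic E(theta) = theta^T Q theta / 2


def grad(th):
    return Q @ th


h, beta, rho, eps = 1000.0, 0.9, 0.999, 1e8  # eps >> squared gradients
adam_traj = adam_without_bias_correction(grad, [1.0, 1.0], h, beta, rho, eps, 500)
hb_traj = heavy_ball(grad, [1.0, 1.0], h * (1 - beta) / np.sqrt(eps), beta, 500)
print(np.max(np.abs(adam_traj - hb_traj)))  # tiny: the two iterations nearly coincide
```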

3 Main Result: ODE Approximating Mini-Batch Adam

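We make only one assumption, which is standard in the literature: the loss $E_k$ for each mini-batch is 4 times continuously differentiable, and the partial derivatives of $E_k$ up to order 4 are bounded. That is, there is a positive constant $M$ such that for $\theta$ in the region of interest

$$\sup_k\max\left\{\max_i|\nabla_i E_k(\theta)|,\ \max_{i,j}|\nabla_{ij}E_k(\theta)|,\ \max_{i,j,s}|\nabla_{ijs}E_k(\theta)|,\ \max_{i,j,s,r}|\nabla_{ijsr}E_k(\theta)|\right\} \le M. \tag{8}$$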
Theorem 3.1.

Assume (8) holds. Let $\{\theta^{(n)}\}$ be iterations of Adam as defined in Definition 1.1, and let $\tilde\theta(t)$ be the continuous solution to the piecewise ODE

(9)

for $t \in [nh, (n+1)h]$ with the initial condition $\tilde\theta(0) = \theta^{(0)}$, where

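Then, for any fixed positive time horizon $T > 0$ there exists a constant $C$ (depending on $T$, $\rho$, $\beta$, $\varepsilon$) such that for any step size $h \in (0, T)$ we have $\|\tilde\theta(nh) - \theta^{(n)}\| \le Ch^2$ for $n \in \{0, 1, \ldots, \lfloor T/h\rfloor\}$.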

Remark 3.2.

In the full-batch setting $E_k \equiv E$, the terms above simplify to

If the iteration number $n$ is large, (9) rapidly becomes the ODE described in (2) and (3).

Derivation sketch.

The proof is in the appendix (this is Theorem E.4; see Appendix A for the overview of the appendix). To help the reader understand how the ODE (9) is obtained, apart from the full proof, we include an informal derivation in Appendix I, and provide an even briefer sketch of this derivation here.

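Our goal is to find a trajectory $\tilde\theta(t)$ such that, denoting $t_n := nh$, we have

$$\tilde\theta_j(t_{n+1}) = \tilde\theta_j(t_n) - h\,\frac{\frac{1}{1-\beta^{n+1}}\sum_{k=0}^{n}\beta^{n-k}(1-\beta)\,\nabla_j E_k(\tilde\theta(t_k))}{\sqrt{\frac{1}{1-\rho^{n+1}}\sum_{k=0}^{n}\rho^{n-k}(1-\rho)\left(\nabla_j E_k(\tilde\theta(t_k))\right)^2 + \varepsilon}} + O(h^3). \tag{10}$$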

Ignoring the terms of order higher than one in , we can take a first-order approximation for granted: with . The challenge is to make this more precise by finding an equality of the form

(11)

where is a known function. This is a numerical iteration to which standard backward error analysis (Chapter IX in Ernst Hairer & Wanner (2006)) can be applied.

Using the Taylor series, we can write

where in the last equality we just replaced with in the -term since it only affects higher-order terms. Doing this again for steps , , , and adding the resulting equations will give for

where we could safely ignore that is not bounded because of exponential averaging. Taking the square of this formal power series in , multiplying this square by and summing up over will give

which, using the expression for the inverse square root of a formal power series , gives us an expansion

A similar process provides an expansion for :

Inserting these two expansions into (10) leads to an expression for :

We are now ready to find an ODE for of the form whose discretization is (11). This is a task for standard backward error analysis: expand into . By Taylor expansion, we have

It is left to equate the terms before the corresponding powers of here and in (11), giving and . Omitting some algebra, the piecewise ODE (9) is derived. ∎
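In our reading, two elementary facts carry the sketch above. First, exponential averaging keeps the unbounded lags $n - k$ harmless, because their weighted mean $\sum_{j \ge 0}(1-\beta)\beta^j j$ never exceeds $\beta/(1-\beta)$. Second, the inverse square root of a power series in $h$ expands as $(a_0 + a_1h + a_2h^2)^{-1/2} = a_0^{-1/2} - \frac{a_1}{2}a_0^{-3/2}h + \left(\frac{3a_1^2}{8}a_0^{-5/2} - \frac{a_2}{2}a_0^{-3/2}\right)h^2 + O(h^3)$ for $a_0 > 0$. The short Python check of both is our own illustration, and the function names are hypothetical.

```python
def ema_lag(beta, n):
    """sum_{j=0}^{n} (1 - beta) * beta**j * j: the mean lag, in steps, of an
    exponential moving average over the last n + 1 iterates."""
    return sum((1 - beta) * beta**j * j for j in range(n + 1))


def inv_sqrt_series(a0, a1, a2, h):
    """Second-order expansion of (a0 + a1*h + a2*h**2) ** (-1/2) around h = 0, for a0 > 0."""
    return (a0 ** -0.5
            - 0.5 * a1 * a0 ** -1.5 * h
            + (0.375 * a1**2 * a0 ** -2.5 - 0.5 * a2 * a0 ** -1.5) * h**2)


def series_error(h, a0=2.0, a1=1.0, a2=-3.0):
    """Absolute error of the second-order expansion at step h (example coefficients)."""
    return abs((a0 + a1 * h + a2 * h**2) ** -0.5 - inv_sqrt_series(a0, a1, a2, h))


# The lag stays bounded by beta / (1 - beta) no matter how long the history is.
print(ema_lag(0.9, 10), ema_lag(0.9, 10_000), 0.9 / (1 - 0.9))
# The truncation error is O(h**3): halving h divides it by about 8.
print(series_error(1e-2) / series_error(5e-3))
```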

4 Illustration: Simple Bilinear Model

We now analyze the effect of the first-order term for Adam in the same model that Barrett & Dherin (2021) and Ghosh et al. (2023) studied. Namely, assume the parameter $\theta = (\theta_1, \theta_2)$ is 2-dimensional, and the loss is given by . The loss is minimized on the hyperbola . We graph the trajectories of Adam in this case. The left part of Figure 2 shows that increasing $\beta$ forces the trajectory to the region with smaller $\|\nabla E(\theta)\|_1$, and increasing $\rho$ does the opposite. The right part shows that increasing the learning rate moves Adam towards the region with smaller $\|\nabla E(\theta)\|_1$ if $\beta > \rho$, and does the opposite if $\rho > \beta$. The former case is just like GD, except that the norm is different if $\varepsilon$ is small compared to the gradient components. All these observations are exactly what Theorem 3.1 predicts.

Figure 2: Left: increasing $\beta$ moves the trajectory of Adam towards the regions with smaller one-norm of the gradient (if $\varepsilon$ is sufficiently small); increasing $\rho$ does the opposite. Right: increasing the learning rate moves the Adam trajectory towards the regions with smaller one-norm of the gradient if $\beta$ is significantly larger than $\rho$, and does the opposite if $\rho$ is larger than $\beta$. The cross denotes the limit point of gradient one-norm minimizers on the level sets $\{\theta : E(\theta) = c\}$. The minimizers are drawn with a dashed line. All Adam trajectories start at .

5 Numerical Experiments

As a first sanity check, we train a relatively small fully-connected neural network with around parameters on the first 10,000 images of MNIST with full-batch Adam for 100 epochs. We plot the value $\|\tilde\theta(nh) - \theta^{(n)}\|_\infty$, i. e. the maximal weight difference between the Adam iteration and the piecewise ODE solution (see footnote 1). Figure 3 shows that even on this very large time horizon the trajectories are close in infinity-norm.

Footnote 1: Since it makes little sense to numerically solve an ODE by further discretization, $\tilde\theta$ is estimated using the iteration (11) with the $O(h^3)$ term ignored. Strictly speaking, this is not the trajectory obtained by the final backward error analysis step but rather the step immediately preceding it (after removing long-term memory but before converting the iteration to an ODE).

Figure 3: $\|\tilde\theta(nh) - \theta^{(n)}\|_\infty$ for an MLP trained with full-batch Adam on truncated MNIST, where $\tilde\theta$ is either the first-order approximation (signGD perturbed by $\varepsilon$) or the second-order approximation to Adam. Precise definitions are provided in Appendix H, specifically (63).

Further, we offer some preliminary empirical evidence that Adam (anti-)penalizes the perturbed one-norm of the gradients, as discussed in Section 2.

Ma et al. (2022) divide training regimes of Adam into three categories:
  • the spike regime, when $\rho$ is much larger than $\beta$, in which the training loss curve contains very large spikes and the training is obviously unstable;
  • the (stable) oscillation regime, when $\rho$ is sufficiently close to $\beta$, in which the loss curve contains fast and small oscillations;
  • the divergence regime, when $\beta$ is much larger than $\rho$, in which Adam diverges.
We exclude the last regime. In the spike regime, the loss spikes to large values at irregular intervals. This has also been observed in the context of large transformers, and Chowdhery et al. (2022) and Molybog et al. (2023) have proposed mitigation strategies. Since it is unlikely that an unstable Adam trajectory can be meaningfully approximated by a smooth ODE solution, we reduce the incidence of large spikes by only considering $\rho$ and $\beta$ that are not too far apart, which is what Ma et al. (2022) recommend doing in practice.

We train Resnet-50, CNNs and Vision Transformers (Dosovitskiy et al., 2020) on the CIFAR-10 dataset with full-batch Adam. In this section, we provide the results for Resnet-50; the pictures for CNNs and Transformers are similar and are given in Section H.4. Figure 4 shows that in the stable oscillation regime, increasing $\rho$ appears to increase the perturbed one-norm and decrease the test accuracy. This is consistent with our analysis: the smaller $\rho$, the more this “norm” is penalized. Figure 5 shows that increasing $\beta$ appears to decrease the perturbed one-norm and increase the test accuracy. This too is consistent with our analysis: the larger $\beta$, the more this norm is penalized. The picture confirms the finding in Ghosh et al. (2023) for the momentum parameter $\beta$ in momentum GD.

Figure 4: Resnet-50 on CIFAR-10 trained with full-batch Adam. As $\rho$ increases, the norm rises and the test accuracy falls. We train longer than necessary for near-perfect classification on the training set (at least 2–3 thousand epochs), and the test accuracies plotted here are maximal. The perturbed norms are also maximal after excluding the initial training period (i. e., the plotted “norms” are at the peaks of the “hills” described in Section 5). All results are averaged across five runs with different initialization seeds. Additional evidence and more details are provided in Appendix H.

Figure 5: Resnet-50 on CIFAR-10 trained with full-batch Adam. The perturbed one-norm falls as $\beta$ increases, and the test accuracy rises. Both metrics are calculated as in Figure 4. All results are averaged across three runs with different initialization seeds.

Figure 6 shows the graphs of $\|\nabla E(\theta)\|_{1,\varepsilon}$ as functions of the epoch number. The “norm” decreases, then rises again, and then decreases further until it flatlines (see footnote 2). Throughout most of the training, the larger $\beta$, the smaller the “norm”. The “hills” of the “norm” curves are higher with smaller $\beta$ and larger $\rho$. This is consistent with our analysis: the larger $\rho$ is compared to $\beta$, the more the correction term prevents $\|\nabla E(\theta)\|_{1,\varepsilon}$ from falling.

Footnote 2: The perturbed one-norm cannot be near zero at the end of training because it is bounded from below by $p\sqrt{\varepsilon}$.

Figure 6: Plots of $\|\nabla E(\theta)\|_{1,\varepsilon}$ after each epoch for full-batch Adam, . Resnet-50 on CIFAR-10, left: , right: .

6 Limitations and Future Directions

As far as we know, an assumption similar to (8) is explicitly or implicitly present in all previous work on backward error analysis of gradient-based machine learning algorithms. (Recently, Beneventano (2023) weakened this assumption for SGD without replacement, but with a somewhat different focus.) There is evidence that large-batch algorithms often operate near or at the edge of stability (Cohen et al., 2021, 2022), where the largest eigenvalue of the hessian can be large. This makes it unclear whether the higher-order partial derivatives can safely be assumed bounded near optimality. In addition, as Smith et al. (2021) point out, backward error analysis can be more accurate in the mini-batch setting. We leave a qualitative analysis of the behavior of the first-order terms in Theorem 3.1 in the mini-batch case as a future direction.

Relatedly, Adam does not always generalize worse than SGD: for transformers, Adam often outperforms SGD (Zhang et al., 2020; Kumar et al., 2022). Moreover, for NLP tasks a long time can be spent training close to an interpolating solution. Our analysis suggests that the anti-regularization effect disappears in this regime, which is consistent with the finding that generalization can be better. However, we believe this explanation is not complete, and more work is needed to connect the implicit bias to the training dynamics of transformers.

In addition, the constant $C$ in Theorem 3.1 goes to infinity as $\varepsilon$ goes to zero. Theoretically, our proof does not exclude the possibility that, for very small $\varepsilon$, the trajectory of the piecewise ODE is close to the Adam trajectory only for small, suboptimal learning rates, at least at later stages of learning. (For the initial learning period, this is not a problem.) The same appears to be true of Proposition 1 in Ma et al. (2022) (zeroth-order approximation by sign-GD). This is especially noticeable in the large-spike regime of training (see Section 5), which, despite being obviously unstable, can still lead to acceptable test errors. It would be worthwhile to investigate this regime in detail.

Acknowledgments

We extend our special thanks to Boris Hanin, Sam Smith, and the anonymous reviewers for their insightful comments and suggestions that greatly enhanced this work. We are also grateful to Jianqing Fan, Pier Beneventano, and Rae Yu for engaging and productive discussions. Cattaneo gratefully acknowledges financial support from the National Science Foundation through DMS-2210561 and SES-2241575. Klusowski gratefully acknowledges financial support from the National Science Foundation through CAREER DMS-2239448. Additionally, we acknowledge the Princeton Research Computing resources, coordinated by the Princeton Institute for Computational Science and Engineering (PICSciE) and the Office of Information Technology’s Research Computing.

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

  • Andriushchenko et al. (2023) Andriushchenko, M., Croce, F., Müller, M., Hein, M., and Flammarion, N. A modern look at the relationship between sharpness and generalization. In Proceedings of the 40th International Conference on Machine Learning, ICML’23. JMLR.org, 2023.
  • Balles et al. (2020) Balles, L., Pedregosa, F., and Le Roux, N. The geometry of sign gradient descent. arXiv preprint arXiv:2002.08056, 2020.
  • Barakat & Bianchi (2021) Barakat, A. and Bianchi, P. Convergence and dynamical behavior of the Adam algorithm for nonconvex stochastic optimization. SIAM Journal on Optimization, 31(1):244–274, 2021.
  • Barrett & Dherin (2021) Barrett, D. and Dherin, B. Implicit gradient regularization. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=3q5IqUrkcF.
  • Beneventano (2023) Beneventano, P. On the trajectories of SGD without replacement. arXiv preprint arXiv:2312.16143, 2023.
  • Bernstein et al. (2018) Bernstein, J., Wang, Y.-X., Azizzadenesheli, K., and Anandkumar, A. signSGD: Compressed optimisation for non-convex problems. In International Conference on Machine Learning, pp. 560–569. PMLR, 2018.
  • Beyer et al. (2022) Beyer, L., Zhai, X., and Kolesnikov, A. Better plain ViT baselines for ImageNet-1k. arXiv preprint arXiv:2205.01580, 2022.
  • Chen et al. (2018) Chen, J., Zhou, D., Tang, Y., Yang, Z., Cao, Y., and Gu, Q. Closing the generalization gap of adaptive gradient methods in training deep neural networks. arXiv preprint arXiv:1806.06763, 2018. URL https://arxiv.org/pdf/1806.06763.
  • Chen et al. (2021) Chen, X., Hsieh, C.-J., and Gong, B. When vision transformers outperform ResNets without pre-training or strong data augmentations. arXiv preprint arXiv:2106.01548, 2021.
  • Chowdhery et al. (2022) Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H. W., Sutton, C., Gehrmann, S., et al. PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
  • Cohen et al. (2021) Cohen, J., Kaur, S., Li, Y., Kolter, J. Z., and Talwalkar, A. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=jh-rTtvkGeM.
  • Cohen et al. (2022) Cohen, J. M., Ghorbani, B., Krishnan, S., Agarwal, N., Medapati, S., Badura, M., Suo, D., Cardoze, D., Nado, Z., Dahl, G. E., et al. Adaptive gradient methods at the edge of stability. arXiv preprint arXiv:2207.14484, 2022.
  • Dosovitskiy et al. (2020) Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Ernst Hairer & Wanner (2006) Hairer, E., Lubich, C., and Wanner, G. Geometric numerical integration. Springer-Verlag, Berlin, 2nd edition, 2006. ISBN 3-540-30663-3.
  • Even et al. (2023) Even, M., Pesme, S., Gunasekar, S., and Flammarion, N. (S)GD over diagonal linear networks: Implicit regularisation, large stepsizes and edge of stability. arXiv preprint arXiv:2302.08982, 2023.
  • Farazmand (2020) Farazmand, M. Multiscale analysis of accelerated gradient methods. SIAM Journal on Optimization, 30(3):2337–2354, 2020.
  • Foret et al. (2021) Foret, P., Kleiner, A., Mobahi, H., and Neyshabur, B. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=6Tm1mposlrM.
  • França et al. (2021) França, G., Jordan, M. I., and Vidal, R. On dissipative symplectic integration with applications to gradient-based optimization. Journal of Statistical Mechanics: Theory and Experiment, 2021(4):043402, 2021.
  • Ghosh et al. (2023) Ghosh, A., Lyu, H., Zhang, X., and Wang, R. Implicit regularization in heavy-ball momentum accelerated stochastic gradient descent. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=ZzdBhtEH9yB.
  • Granziol (2020) Granziol, D. Flatness is a false friend. arXiv preprint arXiv:2006.09091, 2020.
  • Gunasekar et al. (2018a) Gunasekar, S., Lee, J., Soudry, D., and Srebro, N. Characterizing implicit bias in terms of optimization geometry. In International Conference on Machine Learning, pp. 1832–1841. PMLR, 2018a.
  • Gunasekar et al. (2018b) Gunasekar, S., Lee, J. D., Soudry, D., and Srebro, N. Implicit bias of gradient descent on linear convolutional networks. Advances in Neural Information Processing Systems, 31, 2018b.
  • He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
  • Hoffer et al. (2017) Hoffer, E., Hubara, I., and Soudry, D. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. Advances in Neural Information Processing Systems, 30, 2017.
  • Ji & Telgarsky (2018a) Ji, Z. and Telgarsky, M. Gradient descent aligns the layers of deep linear networks. arXiv preprint arXiv:1810.02032, 2018a.
  • Ji & Telgarsky (2018b) Ji, Z. and Telgarsky, M. Risk and parameter convergence of logistic regression. arXiv preprint arXiv:1803.07300, 2018b.
  • Ji & Telgarsky (2019) Ji, Z. and Telgarsky, M. The implicit bias of gradient descent on nonseparable data. In Conference on Learning Theory, pp.  1772–1798. PMLR, 2019.
  • Ji & Telgarsky (2020) Ji, Z. and Telgarsky, M. Directional convergence and alignment in deep learning. Advances in Neural Information Processing Systems, 33:17176–17186, 2020.
  • Jiang et al. (2022) Jiang, K., Malik, D., and Li, Y. How does adaptive optimization impact local neural network geometry? arXiv preprint arXiv:2211.02254, 2022.
  • Keskar & Socher (2017) Keskar, N. S. and Socher, R. Improving generalization performance by switching from Adam to SGD. arXiv preprint arXiv:1712.07628, 2017.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • Kovachki & Stuart (2021) Kovachki, N. B. and Stuart, A. M. Continuous time analysis of momentum methods. Journal of Machine Learning Research, 22(17):1–40, 2021.
  • Kumar et al. (2022) Kumar, A., Shen, R., Bubeck, S., and Gunasekar, S. How to fine-tune vision models with SGD. arXiv preprint arXiv:2211.09359, 2022.
  • Kunin et al. (2020) Kunin, D., Sagastuy-Brena, J., Ganguli, S., Yamins, D. L., and Tanaka, H. Neural mechanics: Symmetry and broken conservation laws in deep learning dynamics. arXiv preprint arXiv:2012.04728, 2020.
  • Lee et al. (2015) Lee, C.-Y., Xie, S., Gallagher, P., Zhang, Z., and Tu, Z. Deeply-supervised nets. In Artificial Intelligence and Statistics, pp. 562–570. PMLR, 2015.
  • Li et al. (2017) Li, Q., Tai, C., and E, W. Stochastic modified equations and adaptive stochastic gradient algorithms. In Precup, D. and Teh, Y. W. (eds.), Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pp. 2101–2110. PMLR, August 2017. URL https://proceedings.mlr.press/v70/li17f.html.
  • Lyu & Li (2019) Lyu, K. and Li, J. Gradient descent maximizes the margin of homogeneous neural networks. arXiv preprint arXiv:1906.05890, 2019.
  • Ma et al. (2022) Ma, C., Wu, L., and E, W. A qualitative study of the dynamic behavior for adaptive gradient algorithms. In Mathematical and Scientific Machine Learning, pp. 671–692. PMLR, 2022.
  • Malladi et al. (2022) Malladi, S., Lyu, K., Panigrahi, A., and Arora, S. On the SDEs and scaling rules for adaptive gradient algorithms. Advances in Neural Information Processing Systems, 35:7697–7711, 2022.
  • Miyagawa (2022) Miyagawa, T. Toward equation of motion for deep neural networks: Continuous-time gradient descent and discretization error analysis. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=qq84D17BPu.
  • Molybog et al. (2023) Molybog, I., Albert, P., Chen, M., DeVito, Z., Esiobu, D., Goyal, N., Koura, P. S., Narang, S., Poulton, A., Silva, R., et al. A theory on Adam instability in large-scale machine learning. arXiv preprint arXiv:2304.09871, 2023.
  • Nacson et al. (2019a) Nacson, M. S., Gunasekar, S., Lee, J., Srebro, N., and Soudry, D. Lexicographic and depth-sensitive margins in homogeneous and non-homogeneous deep models. In International Conference on Machine Learning, pp. 4683–4692. PMLR, 2019a.
  • Nacson et al. (2019b) Nacson, M. S., Lee, J., Gunasekar, S., Savarese, P. H. P., Srebro, N., and Soudry, D. Convergence of gradient descent on separable data. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  3420–3428. PMLR, 2019b.
  • Nacson et al. (2019c) Nacson, M. S., Srebro, N., and Soudry, D. Stochastic gradient descent on separable data: Exact convergence with a fixed learning rate. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  3051–3059. PMLR, 2019c.
  • Qian & Qian (2019) Qian, Q. and Qian, X. The implicit bias of AdaGrad on separable data. Advances in Neural Information Processing Systems, 32, 2019.
  • Rosca et al. (2021) Rosca, M. C., Wu, Y., Dherin, B., and Barrett, D. Discretization drift in two-player games. In International Conference on Machine Learning, pp. 9064–9074. PMLR, 2021.
  • Smith et al. (2021) Smith, S. L., Dherin, B., Barrett, D., and De, S. On the origin of implicit regularization in stochastic gradient descent. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=rq_Qr0c1Hyo.
  • Soudry et al. (2018) Soudry, D., Hoffer, E., Nacson, M. S., Gunasekar, S., and Srebro, N. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878, 2018.
  • Tieleman et al. (2012) Tieleman, T., Hinton, G., et al. Lecture 6.5-RMSProp: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2):26–31, 2012.
  • Wang et al. (2021) Wang, B., Meng, Q., Chen, W., and Liu, T.-Y. The implicit bias for adaptive optimization algorithms on homogeneous neural networks. In Meila, M. and Zhang, T. (eds.), Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pp. 10849–10858. PMLR, July 2021. URL https://proceedings.mlr.press/v139/wang21q.html.
  • Wang et al. (2022) Wang, B., Meng, Q., Zhang, H., Sun, R., Chen, W., Ma, Z.-M., and Liu, T.-Y. Does momentum change the implicit regularization on separable data? Advances in Neural Information Processing Systems, 35:26764–26776, 2022.
  • Wilson et al. (2017) Wilson, A. C., Roelofs, R., Stern, M., Srebro, N., and Recht, B. The marginal value of adaptive gradient methods in machine learning. Advances in Neural Information Processing Systems, 30, 2017.
  • Woodworth et al. (2020) Woodworth, B., Gunasekar, S., Lee, J. D., Moroshko, E., Savarese, P., Golan, I., Soudry, D., and Srebro, N. Kernel and rich regimes in overparametrized models. In Conference on Learning Theory, pp.  3635–3673. PMLR, 2020.
  • Xie et al. (2022) Xie, Z., Wang, X., Zhang, H., Sato, I., and Sugiyama, M. Adaptive inertia: Disentangling the effects of adaptive learning rate and momentum. In International Conference on Machine Learning, pp. 24430–24459. PMLR, 2022.
  • Zhang et al. (2020) Zhang, J., Karimireddy, S. P., Veit, A., Kim, S., Reddi, S., Kumar, S., and Sra, S. Why are adaptive methods good for attention models? Advances in Neural Information Processing Systems, 33:15383–15393, 2020.
  • Zhao et al. (2022) Zhao, Y., Zhang, H., and Hu, X. Penalizing gradient norm for efficiently improving generalization in deep learning. In International Conference on Machine Learning, pp. 26982–26992. PMLR, 2022.
  • Zhou et al. (2020) Zhou, P., Feng, J., Ma, C., Xiong, C., Hoi, S. C. H., et al. Towards theoretically understanding why SGD generalizes better than Adam in deep learning. Advances in Neural Information Processing Systems, 33:21285–21296, 2020.

Appendix A Overview

This appendix provides omitted details and proofs.

We consider two algorithms, RMSProp and Adam, and two versions of each (with the numerical stability parameter ε inside or outside the square root in the denominator). Hence there are four main theorems: Theorem B.4, Theorem C.4, Theorem D.4 and Theorem E.4, each in a section devoted to one algorithm version. The simple induction argument, taken from Ghosh et al. (2023) and essentially the same for all four theorems, relies on an auxiliary local error bound whose corresponding versions are Theorem B.3, Theorem C.3, Theorem D.3 and Theorem E.3. The proof of this bound is elementary but long, and it is carried out through a series of lemmas in Appendix F and Appendix G. Of these four results, we only prove Theorem B.3, since the other three are proven in the same way with obvious changes.
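To make the four variants concrete, here is a minimal NumPy sketch of one step of each, assuming the standard RMSProp (Tieleman et al., 2012) and Adam (Kingma & Ba, 2015) updates, with Adam's usual bias correction. The function names and argument conventions are hypothetical (`beta` and `rho` play the roles of Adam's first- and second-moment averaging parameters), and the exact indexing in Definitions B.1, C.1, D.1 and E.1 may differ.

```python
import numpy as np


def rmsprop_step(theta, v, grad, h, rho, eps, eps_inside=False):
    """One RMSProp step; returns the new parameters and the squared-gradient average."""
    v = rho * v + (1 - rho) * grad**2  # moving average of squared gradients
    # The two variants differ only in where eps enters the denominator.
    denom = np.sqrt(v + eps) if eps_inside else np.sqrt(v) + eps
    return theta - h * grad / denom, v


def adam_step(theta, m, v, grad, n, h, beta, rho, eps, eps_inside=False):
    """One Adam step at iteration n (0-based); returns new parameters and both averages."""
    m = beta * m + (1 - beta) * grad     # first-moment average
    v = rho * v + (1 - rho) * grad**2    # second-moment average
    m_hat = m / (1 - beta ** (n + 1))    # bias corrections as in Kingma & Ba (2015)
    v_hat = v / (1 - rho ** (n + 1))
    denom = np.sqrt(v_hat + eps) if eps_inside else np.sqrt(v_hat) + eps
    return theta - h * m_hat / denom, m, v
```

With eps = 0 the inside and outside variants coincide. On the first Adam step from zero moment estimates, the bias-corrected update with eps outside is h * grad / (|grad| + eps), a sign-like step when eps is small.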

Appendix H contains some details about the numerical experiments.

A.1 Notation

We denote the loss of the th minibatch as a function of the network parameters by , and in the full-batch setting we omit the index and write . As usual, means the gradient of , and nabla with indices means partial derivatives, e.g., is a shortcut for .

The letter will always denote a finite time horizon of the ODEs, will always denote the training step size, and we will replace with when convenient, where is the step number. We use the same notation for the iteration of the discrete algorithm , the piecewise ODE solution and some auxiliary terms for each of the four algorithms: see Definition B.1, Definition C.1, Definition D.1, Definition E.1. This avoids significant notational clutter, and we are careful to reference the relevant definition in all theorem statements.

Appendix B RMSProp with ε Outside the Square Root

Definition B.1.

In this section, for some , , , let the sequence of -vectors be defined for by

(12)

Let be defined as a continuous solution to the piecewise ODE

(13)

for with the initial condition , where , and are -dimensional functions with components

Assumption B.2.
  1. 1.

    For some positive constants , , , we have

  2. 2.

    For some we have for all

    where is defined in Definition B.1.

Theorem B.3 (RMSProp with ε outside: local error bound).

Suppose B.2 holds. Then for all ,

for a positive constant depending on .

The proof of Theorem B.3 is conceptually simple but very technical, so we defer it to Appendix G. Taking it as given for now and combining it with a simple induction argument yields the following global error bound.
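The mechanism behind Theorems B.3 and B.4, a per-step error of higher order in the step size accumulating over about T/h steps into a smaller global error, can be seen numerically in the simplest possible setting. The sketch below uses plain GD on a diagonal quadratic E(θ) = Σ a_i θ_i² / 2 (an illustrative choice, not RMSProp), where the GD iterates, the gradient flow, and the well-known first-order modified flow of GD, which follows the gradient of E + (h/4)‖∇E‖², all have closed forms. The helper name and test values are hypothetical. Halving h should roughly halve the gap to the gradient flow but quarter the gap to the modified flow.

```python
import math


def max_error(h, a=(1.0, 3.0), theta0=(1.0, 1.0), T=1.0, modified=False):
    """Max-norm gap at time T between GD iterates and an ODE flow for E = sum(a_i * theta_i**2) / 2."""
    n = int(round(T / h))  # number of GD steps to reach time T
    t = n * h
    err = 0.0
    for ai, th0 in zip(a, theta0):
        gd = th0 * (1 - h * ai) ** n  # exact GD iterate: theta <- (1 - h a) theta
        # Gradient flow decays at rate a; the modified flow uses the gradient of
        # E + (h/4)||grad E||^2, which adds (h/2) a^2 to the rate.
        rate = ai + (h / 2) * ai**2 if modified else ai
        flow = th0 * math.exp(-rate * t)  # exact ODE solution at time t
        err = max(err, abs(gd - flow))
    return err
```

For example, `max_error(0.01) / max_error(0.005)` is close to 2, while the same ratio with `modified=True` is close to 4, matching a global error of order h for the gradient flow and h² for the modified flow.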

Theorem B.4 (RMSProp with ε outside: global error bound).

Suppose B.2 holds, and

for defined in Definition B.1. Then there exist positive constants , , such that for all

where . The constants can be defined as

Proof.

We prove this by induction over , in the same way as an analogous bound is proven in Ghosh et al. (2023).

The base case is . Indeed, . Then the th component of is

By Theorem B.3, the absolute value of the right-hand side does not exceed , which means . Since , the base case is proven.

Now suppose that for all the claim

is proven. Then

where (1) is by the triangle inequality, (2) is by , and in (3) we used for all .

Next, combining Theorem B.3 with (12), we have

(14)

where, to simplify notation, we put

Using , , we have

(15)

But since

we have

(16)

Combining (15) and (16), we obtain

(17)

where in (1) we used the induction hypothesis and the fact that the bound on has already been proven.

Now note that since , we have , which can be rewritten as

Then we can continue (17):

(18)

Again using , we conclude from (14) and (18) that

finishing the induction step. ∎

B.1 RMSProp with ε outside: full-batch

In the full-batch setting , the terms in (13) simplify to

If is small and the iteration number is large, (13) simplifies to
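As a hedged illustration of this full-batch, large-iteration regime, assume the gradient is frozen at a scalar value g (a simplification). The squared-gradient average after n steps is then ρⁿ·v0 + (1 - ρⁿ)·g², which tends to g², so the RMSProp step with ε outside the square root tends to h·g/(|g| + ε): close to h·sign(g) when |g| is much larger than ε, and close to the GD step (h/ε)·g when |g| is much smaller than ε. Here ρ and v0 denote the averaging parameter and the initial average, and the helper name is hypothetical.

```python
import math


def rmsprop_step_frozen_grad(g, h, rho, eps, v0=0.0, steps=1000):
    """Run the squared-gradient average for `steps` iterations with a constant gradient g,
    then return the next RMSProp step (eps outside the square root)."""
    v = v0
    for _ in range(steps):
        v = rho * v + (1 - rho) * g * g  # converges geometrically to g**2
    return h * g / (math.sqrt(v) + eps)
```

For instance, with h = 0.1 and ε = 1e-8, both g = 1e-3 and g = -50 give steps of magnitude very close to 0.1, whereas g = 1e-10 with ε = 1e-6 gives a step close to 0.1 · 1e-10 / 1e-6.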

Appendix C RMSProp with ε Inside the Square Root

Definition C.1.

In this section, for some , , , let the sequence of -vectors be defined for by

(19)

Let be defined as a continuous solution to the piecewise ODE

(20)

for with the initial condition , where , and are -dimensional functions with components

(21)
Assumption C.2.

For some positive constants , , , we have

Theorem C.3 (RMSProp with ε inside: local error bound).

Suppose C.2 holds. Then for all ,

for a positive constant depending on , where is defined in Definition C.1.

The argument is the same as for Theorem B.3.

Theorem C.4 (RMSProp with ε inside: global error bound).

Suppose C.2 holds. Then there exist positive constants , , such that for all

where ; and are defined in Definition C.1. The constants can be defined as

The argument is the same as for Theorem B.4.

C.1 RMSProp with ε Inside: Full-Batch

In the full-batch setting , the terms in (20) simplify to

If the iteration number is large, (20) quickly reduces to

where
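How quickly the large-iteration form takes over can be estimated under the assumption that the neglected terms carry a factor ρⁿ for an averaging parameter ρ in (0, 1), as is typical for exponential moving averages started from an initial value. The hypothetical helper below returns the first n with ρⁿ ≤ tol.

```python
import math


def steps_until_negligible(rho, tol):
    """Smallest n >= 0 with rho**n <= tol, for 0 < rho < 1 and 0 < tol < 1."""
    # rho**n <= tol  <=>  n * log(rho) <= log(tol)  <=>  n >= log(tol) / log(rho),
    # where the inequality flips because log(rho) < 0.
    return math.ceil(math.log(tol) / math.log(rho))
```

For ρ = 0.99 and tol = 1e-3 this gives n = 688, and for ρ = 0.999 it gives n = 6905. Since -log(ρ) ≈ 1 - ρ, the transient lasts roughly log(1/tol)/(1 - ρ) steps.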

Appendix D Adam with ε Outside the Square Root

Definition D.1.

In this section, for some , , , let the sequence of -vectors be defined for by

or, rewriting,

(22)

Let be defined as a continuous solution to the piecewise ODE

(23)

for with the initial condition , where , , , , , are -dimensional functions with components

(24)
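Exponential moving averages such as Adam's first- and second-moment estimates can be unrolled into explicit weighted sums: m_n = βⁿ·m_0 + (1 - β)·Σ_{k=0}^{n-1} β^{n-1-k}·g_k. The sketch below (hypothetical helper names, scalar inputs) checks that the recursive and unrolled forms agree. Whether the rewriting in (22) uses exactly this indexing convention is an assumption.

```python
def ema_recursive(m0, grads, beta):
    """Apply m <- beta * m + (1 - beta) * g for each g in grads."""
    m = m0
    for g in grads:
        m = beta * m + (1 - beta) * g
    return m


def ema_closed_form(m0, grads, beta):
    """Unrolled form: beta**n * m0 + (1 - beta) * sum_k beta**(n - 1 - k) * g_k."""
    n = len(grads)
    weighted = sum(beta ** (n - 1 - k) * g for k, g in enumerate(grads))
    return beta**n * m0 + (1 - beta) * weighted
```

The unrolled form makes explicit that older gradients enter with geometrically decaying weights and that the initialization m_0 is damped by βⁿ.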
Assumption D.2.
  1. 1.

    For some positive constants , , , we have

  2. 2.

    For some we have for all

    where is defined in Definition D.1.

Theorem D.3 (Adam with ε outside: local error bound).

Suppose D.2 holds. Then for all ,

for a positive constant depending on and .

The argument is the same as for Theorem B.3.

Theorem D.4 (Adam with ε outside: global error bound).

Suppose D.2 holds, and

for defined in Definition D.1. Then there exist positive constants , , such that for all

where . The constants can be defined as

Proof.

Analogously to Theorem B.4, we will prove this by induction over .

The base case is . Indeed, . Then the th component of is

By Theorem D.3, the absolute value of the right-hand side does not exceed , which means . Since , the base case is proven.

Now suppose that for all the claim

is proven. Then

where (1) is by the triangle inequality, (2) is by , and in (3) we used for all .

Next, combining Theorem D.3 with (22), we have

(25)

where, to simplify notation, we put

Using , , we have

(26)

But since

we have

(27)

Similarly,

(28)

Combining (26), (27) and (28), we get

(29)

where in (1) we used the induction hypothesis and the fact that the bound on has already been proven.

Now note that since , we have , which can be rewritten as

By the same logic,

Then we can continue (29):

(30)

Again using , we conclude from (25) and (30) that

finishing the induction step. ∎

Appendix E Adam with ε Inside the Square Root

Definition E.1.

In this section, for some , , , let the sequence of -vectors be defined for by

(31)

Let be defined as a continuous solution to the piecewise ODE

(32)

for with the initial condition , where , , , , , are -dimensional functions with components

(33)
Assumption E.2.

For some positive constants , , , we have

Theorem E.3 (Adam with ε inside: local error bound).

Suppose E.2 holds. Then for all ,

for a positive constant depending on and .

The argument is the same as for Theorem B.3.

Theorem E.4 (Adam with ε inside: global error bound).

Suppose E.2 holds for defined in Definition E.1. Then there exist positive constants , , such that for all

where . The constants can be defined as

The argument is the same as for Theorem D.4.

Appendix F Bounding the Derivatives of the ODE Solution

Our first goal is to show that the first derivative of is uniformly bounded in absolute value. To achieve this, it suffices to bound all the terms on the right-hand side of the ODE (13).

Lemma F.1.

Suppose B.2 holds. Then for all

(34)
(35)

with constants , defined as follows:

Proof of Lemma F.1.

Both bounds are straightforward:

and

concluding the proof of Lemma F.1. ∎

Lemma F.2.

Suppose B.2 holds. Then the first derivative of is bounded in absolute value, uniformly over and , by some positive constant, say .

Proof.

This follows immediately from , (34), (35) and the definition of given in (13). ∎

Our next goal is to show that the second derivative of is bounded in absolute value. For this, we need to bound the first derivatives of all three additive terms on the right-hand side of (13).

Lemma F.3.

Suppose B.2 holds. Then for all , we have

(36)
(37)
(38)
(39)
(40)
(41)
(42)
(43)

with constants , , , , , , , , , , , defined as follows:

Proof of Lemma F.3.

We prove the inequalities one by one.

The bound (36) is straightforward:

The inequality (37) follows immediately from the fact that by (13) we have for

The bound (38) follows immediately from the assumptions.

We will prove (39) by bounding the two additive terms on the right-hand side of the equality

(44)

It is easily shown that the first term in (44) is bounded in absolute value by :

To complete the proof of (39), it remains to show that the second term in (44) is bounded in absolute value by .

To bound , we can use

By the Cauchy-Schwarz inequality applied twice,

Next, for any and

(45)

This gives

We have obtained

(46)

This gives a bound on the second term in (44):

concluding the proof of (39).

We will prove (40) by bounding the four terms in the expression

where

To bound Term1, use , giving

To bound Term2, use , giving

To bound Term3, use , giving

To bound Term4, use (45), giving

The proof of (40) is finished.

The inequality (41) is already proven in (46).

To prove (42), note that the bound (45) gives

(47)
(48)
(49)

Combining the bounds above, we have

We are ready to conclude

It remains to prove (43). Since

and, as we have already seen in the argument for (40),

we are ready to bound

The proof of Lemma F.3 is concluded. ∎

Lemma F.4.

Suppose B.2 holds. Then the second derivative of is bounded in absolute value, uniformly over and , by some positive constant, say .

Proof.

This follows from the definition of given in (13) and the fact that, by Lemma F.3, the first derivatives of all three terms in (13) are bounded. ∎

Finally, we need to argue that the third derivative of is bounded in absolute value. To achieve this, we need to bound the second derivatives of the terms on the right-hand side of (13).

Lemma F.5.

Suppose B.2 holds. Then for all ,

(50)
(51)
(52)
(53)
(54)
(55)

with constants , , , , , defined as follows:

Proof of Lemma F.5.

We prove the inequalities one by one.

The proof of (50) is straightforward:

To prove (51), note that

giving by (47)

To prove (52), note that

giving by (45) and (51)

The bound (53) follows from (45), (51) and

To justify (54), put temporarily , and use

combined with

To justify (55), put temporarily

and use

from which (55) follows.

The proof of Lemma F.5 is concluded. ∎

Lemma F.6.

Suppose B.2 holds. Then the third derivative of is bounded in absolute value, uniformly over and , by some positive constant, say .

Proof.

By (38), (46) and (55)

By the definition of , this means that its derivatives up to order two are bounded. Similarly, the same is true for .

It follows from (52) and its proof that the derivatives up to order two of

are also bounded.

These considerations give the boundedness of the second derivative of the term

in (13). The boundedness of the second derivatives of the other two terms is shown analogously. By (13) and since , this means

for some positive constant . ∎

Appendix G Proof of Theorem B.3

Our next objective is to prove the following equality and to identify the constant in it:

We will make some preparations and achieve this objective in Lemma G.5. Then we will conclude the proof of Theorem B.3.

Lemma G.1.

Suppose B.2 holds. Then for all , , we have

(56)
Proof.

(56) follows from the mean value theorem applied times. ∎

Lemma G.2.

In the setting of Lemma G.1, for any we have

Proof.

By the Taylor expansion of on the segment at on the left

Combining this with (37) gives

(57)

Now applying the mean-value theorem times, we have by (46)

and in particular

Combining this with (57), we conclude the proof of Lemma G.2. ∎

Lemma G.3.

In the setting of Lemma G.1,

Proof.

Fix .

Note that

where (G) is by Lemma G.2. ∎

Lemma G.4.

Suppose B.2 holds. Then for all ,

(58)

and

(59)

with and defined as follows:

Proof.

Note that

where (G) is by (56). Using the triangle inequality, we can conclude

(58) is proven.

We continue by showing

(60)

To prove this, use

with

and bounding

where (G) is by Lemma G.3. (60) is proven.

We turn to the proof of (59). By (60) and the triangle inequality

where

It remains to combine this with

This gives

where in (G) we used that . (59) is proven. ∎

Lemma G.5.

Suppose B.2 holds. Then

Proof.

Note that if , , we have

By the triangle inequality,

Apply this with

and use bounds

by Lemma G.4. ∎

We are finally ready to prove Theorem B.3.

Proof of Theorem B.3.

By (42) and (43), the first derivative of the function

is bounded in absolute value by a positive constant . By (13), this means

Combining this with

by Taylor expansion, we get

(61)

Using

with defined as

by (13), and calculating the derivative, it is easy to show

(62)

for a positive constant , where

From (61) and (62), by the triangle inequality

which, using (13), is rewritten as

It remains to combine this with Lemma G.5, which gives the assertion of the theorem with

Appendix H Numerical Experiments

H.1 Models

We use small modifications of the ResNet-50 and ResNet-101 implementations in the torchvision library for training on CIFAR-10 and CIFAR-100. The first convolution layer, conv1, has kernel, stride 1 and “same” padding, and is followed by batch normalization and ReLU. Max pooling is removed; otherwise, conv2_x to conv5_x are as described in He et al. (2016) (see Table 1 there), except that downsampling is performed by the middle convolution of each bottleneck block, as in version 1.5 (https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch). After conv5_x there is global average pooling and a 10- or 100-way fully connected layer (for CIFAR-10 and CIFAR-100, respectively).

The MLP that we use to show the closeness of trajectories in Figure 3 consists of two fully connected layers, each with 32 units and GELU activation, followed by a fully connected layer with 10 units.
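The MLP just described can be sketched in PyTorch as follows. The input dimension is not stated here, so this sketch assumes flattened inputs and lets `nn.LazyLinear` infer the first layer's input size on the first forward pass; the helper name is hypothetical.

```python
import torch
from torch import nn


def make_mlp(num_classes: int = 10) -> nn.Sequential:
    """Two 32-unit GELU layers followed by a linear output layer."""
    return nn.Sequential(
        nn.Flatten(),                # flatten images to vectors
        nn.LazyLinear(32), nn.GELU(),  # input size inferred on first call
        nn.Linear(32, 32), nn.GELU(),
        nn.Linear(32, num_classes),  # 10 output units by default
    )
```

Because the first layer is lazy, the model must see one batch before its parameters exist (for example, before they are passed to an optimizer).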

In Figure 3, the curves called “first order” plot and the curves called “second order” plot , where is the Adam iteration defined in Definition 1.1 and

(63)

for and as defined in Section 3, with the same initial point .

H.2 Data Augmentation

We subtract the per-pixel mean and divide by the standard deviation, and we use the data augmentation scheme from Lee et al. (2015), following He et al. (2016), Section 4.2. During each pass over the training dataset, each initial image is padded evenly with zeros so that it becomes , then a random crop is applied so that the picture becomes again, and a random (probability ) horizontal (left-to-right) flip is applied.

H.3 Experiment Details

In the experiments whose results are reported in Figures 4 and 5 of the main paper, we train for a few thousand epochs and stop training when the train accuracy is near-perfect (Figure 11) and the test accuracy no longer improves significantly (Figure 12). Therefore, the maximal test accuracies are the final ones reached, and the maximal perturbed one-norms, after excluding the initial fall at the beginning of training, are at the peaks of the “hills” on the norm curves (Figure 12).
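One way to operationalize "the maximum after excluding the initial fall" for a per-epoch metric curve is sketched below. This is a hypothetical helper, not necessarily the procedure used for the figures; noisy curves may need smoothing first, since a single upward blip would end the initial stretch early.

```python
def peak_after_initial_fall(values):
    """Skip the initial non-increasing stretch of a curve, then return the max of the rest."""
    if not values:
        raise ValueError("empty curve")
    i = 0
    # Advance while the curve is still falling (or flat).
    while i + 1 < len(values) and values[i + 1] <= values[i]:
        i += 1
    return max(values[i:])
```

For example, for the curve [5, 4, 3, 3.5, 4.2, 3.9] the initial fall ends at 3, and the returned peak is 4.2.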

Since the full dataset does not fit into GPU memory, we divide it into 100 “ghost batches” and accumulate the gradients before doing one optimization step. This means that we use ghost batch normalization (Hoffer et al., 2017) as opposed to full-dataset batch normalization, similarly to Cohen et al. (2021).
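A minimal PyTorch sketch of this gradient accumulation is given below, assuming the data are already tensors in memory and the loss averages over its batch (as `nn.CrossEntropyLoss()` does with its default `reduction='mean'`). The helper name is hypothetical. Each chunk's loss is reweighted by its share of the dataset, so the accumulated `.grad` equals the gradient of the mean loss over all examples when the model has no batch-dependent layers; with BatchNorm in train mode, each chunk is normalized with its own statistics, which is exactly ghost batch normalization.

```python
import torch
from torch import nn


def accumulate_full_batch_grad(model, loss_fn, inputs, targets, num_ghost_batches=100):
    """Fill each parameter's .grad with the full-dataset gradient, one chunk at a time."""
    model.zero_grad()
    n_total = inputs.shape[0]
    for x, y in zip(torch.chunk(inputs, num_ghost_batches),
                    torch.chunk(targets, num_ghost_batches)):
        # Weight by chunk size so chunks of unequal size are handled correctly.
        loss = loss_fn(model(x), y) * (x.shape[0] / n_total)
        loss.backward()  # gradients add up in .grad across chunks
```

A full-batch optimization step is then `accumulate_full_batch_grad(model, loss_fn, X, Y)` followed by `optimizer.step()`.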

H.4 Additional Evidence

We provide evidence that the results in Figures 4 and 5 are robust to changes of architecture. In Figures 7 and 8, we show that the pictures are similar for a simple CNN created by the following code:

```python
import torch.nn as nn


# The wrapper name is ours; the layer list is the CNN used for Figures 7 and 8.
def make_simple_cnn(num_classes: int) -> nn.Sequential:
    layers = [
        # First block
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # Second block
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # Third block
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),

        # Flatten and Dense layers
        # 128 * 4 * 4 assumes 32x32 inputs (three 2x2 poolings: 32 -> 16 -> 8 -> 4)
        nn.Flatten(),
        nn.Linear(in_features=128 * 4 * 4, out_features=512),
        nn.ReLU(),
        nn.Linear(in_features=512, out_features=num_classes),
    ]
    return nn.Sequential(*layers)
```

In Figures 9 and 10, we show that the same conclusions can be drawn for a Vision Transformer (Dosovitskiy et al., 2020; Beyer et al., 2022). In these experiments, we use the SimpleViT architecture from the vit-pytorch library with patches, 6 transformer blocks with 16 heads, embedding size 512 and MLP dimension 1024 (following Andriushchenko et al., 2023).


Figure 7: A simple CNN trained on CIFAR-10 with full-batch Adam, , . As increases, the perturbed one-norm rises and the test accuracy falls. Both metrics are calculated as in Figures 4 and 5 of the main paper. All results are averaged across five runs with different initialization seeds.


Figure 8: A simple CNN trained on CIFAR-10 with full-batch Adam, , . The perturbed one-norm falls as increases, and the test accuracy rises. Both metrics are calculated as in Figures 4 and 5 of the main paper. All results are averaged across three runs with different initialization seeds.


Figure 9: A vision transformer trained on CIFAR-10 with full-batch Adam. The setting and conclusions are the same as in Figure 7.


Figure 10: A vision transformer trained on CIFAR-10 with full-batch Adam. The setting and conclusions are the same as in Figure 8.


Figure 11: Train loss and train accuracy curves for full-batch Adam, ResNet-50 on CIFAR-10, , , .


Figure 12: Test accuracy and after each epoch. The setting is the same as in Figure 11.

Appendix I Adam with ε Inside the Square Root: Informal Derivation

Our goal is to find a trajectory such that

Result I.1.

For we have

(64)
Derivation.

We take

for granted. Using this and the Taylor series, we can write

where in the last equality we just replaced with in the -term, since this only affects higher-order terms. Now, doing this again for step instead of step , we have

where in the last equality we again replaced with , since this only affects higher-order terms. Proceeding like this and adding the resulting equations, we have, for , that

where we ignored the fact that is not bounded (we will get away with this because of exponential averaging). Hence, taking the square of this formal power series,

Summing up over , we have

which, using the expression for the inverse square root of a formal power series , gives us

Similarly,

We conclude
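The inverse-square-root expansion used in this derivation, for a formal power series a0 + a1·h + a2·h² + … with a0 > 0, begins a0^(-1/2)·(1 - (a1/(2·a0))·h + (3·a1²/(8·a0²) - a2/(2·a0))·h² + O(h³)); it follows from (1 + x)^(-1/2) = 1 - x/2 + 3x²/8 - … with x = (a1/a0)·h + (a2/a0)·h². The sketch below (hypothetical helper names, truncated at second order, arbitrary test coefficients) checks numerically that the truncation error shrinks by a factor of about 8 when h is halved, as an O(h³) remainder should.

```python
def inv_sqrt_exact(h, a0, a1, a2):
    """(a0 + a1 h + a2 h^2)^(-1/2), assuming the base stays positive."""
    return (a0 + a1 * h + a2 * h**2) ** -0.5


def inv_sqrt_series(h, a0, a1, a2):
    """Second-order truncation of the inverse-square-root series in h."""
    c1 = -a1 / (2 * a0)
    c2 = 3 * a1**2 / (8 * a0**2) - a2 / (2 * a0)
    return a0**-0.5 * (1 + c1 * h + c2 * h**2)
```

With a0 = 2, a1 = 0.5 and a2 = -0.3, the error at h = 0.01 is about 8 times the error at h = 0.005.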

Result I.2.

For , the modified equation is (32).

Derivation.

Assume that the modified flow for satisfies where

By Taylor expansion, we have

(65)

Using Result I.1 and equating the coefficients of the corresponding powers of in (64) and (65), we obtain

(66)

It remains to find . Using

we have

Inserting this into (66) concludes the proof. ∎
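For comparison, the same coefficient-matching procedure is shown below for plain GD, θ^(n+1) = θ^(n) - h∇E(θ^(n)), a simpler standard case not treated in this appendix. Writing the modified flow as a perturbation of the gradient flow with an unknown first-order correction f_1 (our notation), using that the second time derivative along the flow is the Jacobian of the vector field applied to the field itself, which equals ∇²E ∇E + O(h), and matching the h² terms with the GD step recovers the well-known first-order correction:

```latex
\begin{aligned}
\dot{\tilde\theta} &= -\nabla E(\tilde\theta) + h\,f_1(\tilde\theta) + O(h^2),
  \qquad \ddot{\tilde\theta} = \nabla^2 E\,\nabla E + O(h), \\
\tilde\theta(t+h) &= \tilde\theta(t) + h\,\dot{\tilde\theta}(t) + \tfrac{h^2}{2}\,\ddot{\tilde\theta}(t) + O(h^3) \\
  &= \tilde\theta - h\,\nabla E(\tilde\theta)
     + h^2\Big(f_1(\tilde\theta) + \tfrac12\,\nabla^2 E(\tilde\theta)\,\nabla E(\tilde\theta)\Big) + O(h^3), \\
\text{matching } h^2 \text{ with GD:}\quad
  f_1 &= -\tfrac12\,\nabla^2 E\,\nabla E = -\tfrac14\,\nabla \|\nabla E\|^2, \\
\text{hence}\quad
  \dot{\tilde\theta} &= -\nabla\Big(E + \tfrac{h}{4}\,\|\nabla E\|^2\Big) + O(h^2).
\end{aligned}
```

The Adam derivation above follows the same pattern, with the extra work going into expanding the moment averages and the inverse square root in powers of h.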