
Edge of (Stochastic) Stability made simple — Part II: the mini-batch case

Feb 23, 2026
Pierfrancesco Beneventano

Part I: full-batch EOS · Part III: practical implications

Based on Edge of Stochastic Stability: Revisiting the Edge of Stability for SGD (Andreyev & Beneventano, arXiv:2412.20553); correspondence to Pierfrancesco Beneventano, pierb@mit.edu.



Acknowledgements / lineage.
Part I leaned heavily on the full-batch EOS story developed by Jeremy Cohen, Alex Damian, and collaborators (and their excellent Central Flows writeups). Part II is where we pivot to the mini-batch world, following our paper with Arseniy Andreyev.

Conceptual map (where this is going):

  • Part I: full-batch EOS: why “wrong step sizes” can still train, and the two mechanisms (local instability + progressive sharpening).
  • Part II (this post): SGD: why the Part I diagnostic fails, what “edge” should mean for stochastic dynamics, and Edge of Stochastic Stability (EoSS).
  • Part III: why any of this matters: batch-size effects, hyperparameters, flatness, and why common “SGD ≈ GD + noise” models can miss the point.

Part II: the mini-batch case

In Part I we had one landscape and a deterministic update.
Now we have a distribution of mini-batch landscapes and a stochastic update.

So the core question becomes:

What does it even mean to be “at the edge” when the landscape moves under your feet?


0) The puzzle (why the Part I diagnostic fails)

In full-batch GD, a crisp story holds:

  • the loss starts oscillating when you cross a curvature stability boundary;
  • sharpness rises, then self-regulates near a threshold like 2/η.

In SGD, two empirical facts collide with that story:

  • Oscillations are cheap: SGD loss wiggles even on convex quadratics with stable step sizes.
  • λ_max(∇²L) is the wrong gauge: the full-batch top eigenvalue often sits below 2/η and depends strongly on the batch size b, even though training still has edge-like “events”.

So: if oscillations are always present, and λ_max is not the saturating quantity…
what is the actual edge?


Notation
  • Dataset: {z_i}_{i=1}^N
  • Full loss: L(θ) = (1/N) Σ_{i=1}^N ℓ(θ; z_i)
  • Mini-batch: B ⊂ {1, …, N} with |B| = b
  • Mini-batch loss: L_B(θ) = (1/b) Σ_{i∈B} ℓ(θ; z_i)
  • SGD update: θ_{t+1} = θ_t − η ∇L_{B_t}(θ_t)
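To ground the notation, here is a minimal NumPy sketch of one SGD step. The least-squares per-example loss, the data sizes, and the helper name `grad_LB` are illustrative assumptions, not objects from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: N examples with per-example loss ℓ(θ; z_i) = 0.5 (x_i·θ - y_i)²
N, d = 32, 4
X = rng.normal(size=(N, d))
theta_true = rng.normal(size=d)
y = X @ theta_true

def grad_LB(theta, idx):
    """Gradient of the mini-batch loss L_B(θ) = (1/b) Σ_{i∈B} ℓ(θ; z_i)."""
    Xb, yb = X[idx], y[idx]
    return Xb.T @ (Xb @ theta - yb) / len(idx)

# One SGD step: θ_{t+1} = θ_t − η ∇L_{B_t}(θ_t)
eta, b = 0.1, 8
theta = np.zeros(d)
batch = rng.choice(N, size=b, replace=False)
theta = theta - eta * grad_LB(theta, batch)
```

Note that both the step direction and (as discussed next) the curvature seen by the step depend on which batch was drawn.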

1) One step, two kinds of “noise”: gradient noise vs curvature noise

A common mental model is: “SGD = GD + gradient noise”.

That captures one randomness source (the gradient estimator), but misses another:

  • each step also sees a different mini-batch Hessian ∇²L_{B_t}(θ_t).

So the geometry is genuinely stochastic: both directions and curvatures fluctuate.
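A quick way to see curvature noise: for a hypothetical least-squares loss (so the mini-batch Hessian is H_B = X_Bᵀ X_B / b), the top eigenvalue of H_B visibly fluctuates from batch to batch. All names and sizes below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
N, d, b = 64, 5, 8
X = rng.normal(size=(N, d))

def top_eig_HB(idx):
    # Mini-batch Hessian of a least-squares loss: H_B = X_Bᵀ X_B / b
    Xb = X[idx]
    H = Xb.T @ Xb / len(idx)
    return np.linalg.eigvalsh(H)[-1]

# Top curvature seen by 200 different random mini-batches
lams = [top_eig_HB(rng.choice(N, b, replace=False)) for _ in range(200)]
spread = max(lams) - min(lams)  # curvature noise: strictly positive spread
```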


2) Oscillations are not diagnostic — because there are two kinds

Here is the clean conceptual separation that makes the rest of the story click:

Type-1 oscillations (stable wobbling)

These are oscillations that exist even when the dynamics is stable. They are driven by gradient noise with a non-vanishing step size. They can look dramatic in the loss, but they do not imply edge-like instability.

Type-2 oscillations (edge-like, curvature-driven)

These are oscillations associated with saturating a genuine instability boundary on a local quadratic model. They are the stochastic analogue of EOS behavior: if you push η up (or b down) a little, you can trigger a runaway “catapult”.
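A toy illustration of the two types (illustrative, not from the paper): on the quadratic L(θ) = ½λθ² with additive gradient noise, a stable step size produces a bounded Type-1 wobble, while ηλ > 2 produces a Type-2 blow-up:

```python
import numpy as np

rng = np.random.default_rng(2)

def run(eta, lam, steps=500, noise=0.1):
    """Noisy GD on L(θ) = 0.5 λ θ² with additive gradient noise ξ_t."""
    theta = 1.0
    traj = []
    for _ in range(steps):
        g = lam * theta + noise * rng.normal()  # noisy gradient estimate
        theta -= eta * g
        traj.append(0.5 * lam * theta**2)
    return np.array(traj)

stable = run(eta=0.5, lam=1.0)    # ηλ = 0.5 < 2: Type-1 wobble, loss stays bounded
unstable = run(eta=0.5, lam=4.2)  # ηλ = 2.1 > 2: Type-2 instability, loss explodes
```

The stable run still oscillates (its loss never settles to a constant), which is exactly why oscillation alone is a weak diagnostic.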


3) A falsifiable definition of “the edge” for stochastic dynamics

If we want an EOS-like concept for SGD, we should demand:

  1. a criterion that actually implies instability (on the local quadratic approximation), and
  2. a quantity that empirically saturates during training.

This leads to a simple philosophy:

Edge = saturation of a valid instability criterion (for the dynamics you run).

And the key operational tool is:

The checkpoint-level edge test (restart + tiny destabilization)

  1. Train normally to a checkpoint θ_t with hyperparameters (η, b).
  2. Restart from θ_t with a slightly more aggressive setting, e.g.
    • η ↑ (increase the learning rate), or
    • b ↓ (decrease the batch size).
  3. Compare the continuation curves.

If you truly were saturating a tight instability boundary, the perturbed run should exhibit a catapult: a sudden excursion + loss spike that the baseline does not.

# Edge test (checkpoint-level)

train with (η, b) → checkpoint θ_t

baseline: continue from θ_t with (η, b)
perturb:  continue from θ_t with (η', b') where η' > η or b' < b (small change)

if baseline is stable but perturb catapults:
    evidence: θ_t was near a tight instability boundary for SGD
else:
    whatever you measured is probably not a valid instability criterion
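The protocol above can be mimicked on a toy quadratic (illustrative numbers; a real edge test would perturb an actual training job): a checkpoint taken just inside the stable region continues quietly at the same η, but catapults under a small η ↑:

```python
import numpy as np

rng = np.random.default_rng(3)
lam = 10.0  # curvature of a toy quadratic L(θ) = 0.5 λ θ²

def continue_from(theta, eta, steps=300, noise=0.01):
    """Continue noisy GD from a checkpoint and return the final loss."""
    for _ in range(steps):
        theta -= eta * (lam * theta + noise * rng.normal())
    return 0.5 * lam * theta**2

# Checkpoint θ_t sitting near the boundary: η λ = 1.9 < 2
theta_t, eta = 0.5, 0.19
final_baseline = continue_from(theta_t, eta)         # same η: stays put
final_perturb = continue_from(theta_t, eta * 1.11)   # η' λ ≈ 2.1 > 2: catapults
```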

4) The tempting trap: a quantity that certifies oscillations (but not instability)

You can derive (via a second-order expansion of the full loss along an SGD step) a scalar that often sits near 2/η and correlates with oscillatory behavior.

That scalar is sometimes called Gradient–Noise Interaction (GNI), and it is genuinely useful — but for a different purpose:

  • GNI ≈ 2/η is a certificate that the loss is oscillatory (Type-1 oscillations can already do this).
  • It is not a certificate that you are near a curvature instability boundary.

What goes wrong if you use GNI as the “edge”

  • GNI can hit 2/η quickly, without any progressive sharpening.
  • If you increase η slightly while still in a stable Type-1 regime, you may just get a bigger wobble, with no catapult.

So: GNI is an oscillation meter, not an instability boundary.


5) The quantity SGD actually “feels”: directional curvature of the mini-batch step

SGD steps along g_B(θ) := ∇L_B(θ).
So the curvature that matters on that step is the Rayleigh quotient of the mini-batch Hessian along that mini-batch gradient:

StepSharpness_B(θ) := g_B(θ)^⊤ ∇²L_B(θ) g_B(θ) / ‖g_B(θ)‖²

Now average over batch sampling:

BatchSharpness(θ) := E_B[ ∇L_B(θ)^⊤ ∇²L_B(θ) ∇L_B(θ) / ‖∇L_B(θ)‖² ]
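A tiny worked example of the definition (numbers purely hypothetical): with two equally likely mini-batch landscapes, Batch Sharpness is the average of the per-batch Rayleigh quotients:

```python
import numpy as np

# Two equally likely "mini-batch landscapes" with known Hessians and gradients
H1, g1 = np.diag([3.0, 1.0]), np.array([1.0, 0.0])  # sharp batch, gradient on the sharp axis
H2, g2 = np.diag([1.0, 1.0]), np.array([0.0, 1.0])  # flat batch

def step_sharpness(g, H):
    # Rayleigh quotient of the mini-batch Hessian along the mini-batch gradient
    return g @ H @ g / (g @ g)

# StepSharpness is 3.0 for batch 1 and 1.0 for batch 2, so:
batch_sharpness = 0.5 * (step_sharpness(g1, H1) + step_sharpness(g2, H2))  # 2.0
```

In this toy, a run at η = 1.0 would sit exactly at Batch Sharpness = 2/η.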

6) Edge of Stochastic Stability (EoSS): what actually saturates at 2/η

Here is the empirical headline: during mini-batch training, Batch Sharpness rises (progressive sharpening) and then saturates at 2/η, across batch sizes.

And this resolves the earlier puzzle:

  • The full-batch λ_max(∇²L(θ)) typically plateaus below 2/η, and the plateau depends on batch size.
  • Yet Batch Sharpness is the quantity that locks to 2/η across batch sizes.

A sharper (and subtle) consequence: what happens to λ_max once EoSS is reached

Empirically, a phase transition occurs:

  • λ_max increases only as long as Batch Sharpness increases.
  • Once Batch Sharpness plateaus at 2/η, λ_max stops increasing (and if it moves, it tends to decrease).
  • If you change hyperparameters mid-training, the “location” (hence λ_max) can be path-dependent; Batch Sharpness reacts immediately and then re-saturates.

Cheat-sheet: four “sharpness-like” quantities
  • Full-batch sharpness: λ_max(∇²L)
  • IAS (Interaction-Aware Sharpness): directional curvature of the full-batch Hessian along mini-batch gradient directions
  • GNI: a directional “oscillation certifier” from a second-order loss expansion (tracks Type-1 oscillations)
  • Batch Sharpness: directional curvature of mini-batch Hessians along their own gradients (the EoSS object)

EoSS is specifically the statement about Batch Sharpness.


7) Catapults: why stochasticity makes the edge feel “spiky”

In full-batch EOS, the stabilizing mechanism is deterministic.

In EoSS, what stabilizes is an expectation. So per-step randomness matters:

  • most steps hover near-threshold;
  • occasionally, you sample a “streak” of unusually sharp mini-batches;
  • those steps overshoot, triggering a catapult: a large excursion + a loss spike.

Then, often:

  • the trajectory re-enters a region where Batch Sharpness is below 2/η,
  • progressive sharpening resumes,
  • and the process returns to the hovering regime.
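The hover-streak-catapult picture can be mimicked with a toy quadratic whose curvature is redrawn each step (illustrative only; this sketch hard-codes the streak and omits the progressive-sharpening recovery): mostly λ_t below 2/η, with a brief streak above it:

```python
import numpy as np

eta = 1.0
# Per-step "mini-batch" curvature λ_t: usually 1.8 (η λ < 2, stable),
# with a streak of unusually sharp batches (λ = 2.5, η λ > 2) at steps 100–109
lams = np.full(300, 1.8)
lams[100:110] = 2.5

theta, traj = 1e-3, []
for lam in lams:
    theta *= (1 - eta * lam)  # one step on the per-batch quadratic 0.5 λ_t θ²
    traj.append(abs(theta))
traj = np.array(traj)
# traj decays, spikes sharply during the streak (the "catapult"), then decays again
```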

8) A short theory interlude: why 2/η shows up again (but differently)

In Part I, 2/η came from the stability of I − ηH on a quadratic.

In SGD, 2/η reappears because the same second-order algebra applies along the mini-batch step direction. The local instability certificate becomes directional and batch-dependent.

One blog-friendly way to remember it:

A mini-batch step decreases its own local quadratic model exactly when the curvature along the step is below 2/η; EoSS is that criterion saturating, on average over batches.

(We keep the proof details in the paper; the important point here is: this is an instability criterion, not just an oscillation meter.)
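For intuition only, here is a one-line sketch of that second-order algebra on the mini-batch quadratic model (the precise statement is in the paper):

```latex
% Expand the mini-batch loss along its own gradient step, g_B = \nabla L_B(\theta):
L_B(\theta - \eta g_B)
  \approx L_B(\theta) - \eta \|g_B\|^2
          + \tfrac{\eta^2}{2}\, g_B^\top \nabla^2 L_B(\theta)\, g_B
% The step decreases L_B (to second order) iff
%   g_B^\top \nabla^2 L_B(\theta)\, g_B / \|g_B\|^2 \;<\; 2/\eta ,
% and averaging the left-hand side over batches gives BatchSharpness < 2/\eta.
```

Averaging the per-step condition over batch sampling is what turns it into the EoSS threshold for Batch Sharpness.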


9) What this changes conceptually (previewing Part III)

EoSS is not just “one more sharpness metric”. It forces a conceptual reframe:

  • It’s not about minima: stabilization can happen in regions with no stationary points.
  • Location becomes distributional: what matters is the distribution of mini-batch Hessians, not only the mean/full-batch Hessian.
  • “Flatness” becomes stability-grounded: curvature proxies only matter insofar as they relate to Batch Sharpness along the trajectory.

And it also clarifies why some common proxies for SGD can fail:

  • If the governing object is batch-dependent curvature, then “GD + generic noise” (or standard SDE models) can be a different dynamical system unless they preserve the relevant mini-batch geometry.

Part III is where we make this fully practical: hyperparameter tuning, diagnostics, and what modeling assumptions break.

Next: Part III: Getting practical


10) Summary (the things to remember)

  • Oscillations alone are not diagnostic: SGD wiggles (Type-1) even when perfectly stable.
  • The full-batch λ_max typically plateaus below 2/η and is not the saturating quantity under SGD.
  • What saturates at 2/η is Batch Sharpness: the expected directional curvature of mini-batch Hessians along their own gradients.
  • Because stabilization holds only in expectation, streaks of sharp mini-batches trigger catapults: spike, recovery, re-saturation.
  • The checkpoint edge test (restart with η ↑ or b ↓) is the falsifiable way to check you are at a genuine instability boundary.


Appendix: measuring Batch Sharpness in practice (optional)

Goal: estimate

E_B[ g_B^⊤ H_B g_B / ‖g_B‖² ],   where   g_B = ∇L_B(θ) and H_B = ∇²L_B(θ).

Practical Monte Carlo estimator (no explicit Hessian needed). A PyTorch-style sketch, assuming θ is a flat tensor with requires_grad=True and that loss_fn(θ, batch) evaluates L_B(θ); the names are illustrative:

import torch

# Batch Sharpness estimator at θ (Monte Carlo over mini-batches)
def batch_sharpness(loss_fn, theta, sample_batch, num_samples):
    vals = []
    for _ in range(num_samples):
        loss = loss_fn(theta, sample_batch())              # L_B(θ)
        (g,) = torch.autograd.grad(loss, theta, create_graph=True)
        v = g.detach()
        (Hv,) = torch.autograd.grad(g @ v, theta)          # Hessian-vector product ∇²L_B(θ) v
        vals.append((v @ Hv) / (v @ v))                    # Rayleigh quotient along g_B
    return torch.stack(vals).mean()

Tip: if you want a quick-and-dirty “edge test” for your own run:

  • log Batch Sharpness over training,
  • checkpoint near saturation,
  • restart with a small η ↑ or b ↓,
  • check if the perturbed continuation catapults.
