
Paper: Flat minima vs sharp minima (Keskar et al., 2017)

The paper "On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima" by Keskar et al. (2017) investigates the relationship between batch size, the geometry of the loss landscape, and generalization in deep learning. A key contribution is numerical evidence for the flat minima vs. sharp minima hypothesis, which offers an explanation for why smaller batches often generalize better than larger ones.


Key Contributions of the Paper

  1. Generalization Gap Observation:

    • Large-batch training often reaches worse test accuracy than small-batch training, even when both achieve similar training accuracy. In the paper's experiments, small batches contain 256 samples and large batches contain 10% of the training set.

    • This is called the "generalization gap."

  2. Flat vs. Sharp Minima Hypothesis:

    • Small-batch methods tend to converge to flat minima (wide, low-curvature regions in the loss landscape).

    • Large-batch methods tend to find sharp minima (narrow, high-curvature regions).

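    • The paper argues that flat minima generalize better because the loss there is less sensitive to small perturbations of the parameters, so a slight mismatch between the training and test loss landscapes costs little.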

  3. Empirical Evidence:

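    • Measures the "sharpness" of a minimum as the largest relative increase in loss within a small box around it, found approximately with L-BFGS-B. This is closely related to the largest Hessian eigenvalues (larger eigenvalues → sharper minima) but avoids computing the Hessian.

    • Shows that large-batch training converges to sharper minima, while small-batch training finds flatter ones.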
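
For reference, the sharpness metric can be written as follows (written out here as a sketch; see the paper for the exact statement). Here x is the minimizer, f the training loss, ε the box size, and A an n × p matrix whose columns span the explored subspace (A = I for the full parameter space), with A⁺ its pseudo-inverse:

$$\mathcal{C}_\epsilon = \left\{ z \in \mathbb{R}^p : \lvert z_i \rvert \le \epsilon \left( \lvert (A^{+}x)_i \rvert + 1 \right), \; i = 1, \dots, p \right\}$$

$$\phi_{x,f}(\epsilon, A) = \frac{\max_{y \in \mathcal{C}_\epsilon} f(x + Ay) - f(x)}{1 + f(x)} \times 100$$

Scaling the box by |x_i| + 1 keeps the neighborhood meaningful for weights of very different magnitudes, and the 1 + f(x) denominator turns the score into a relative increase. Because the worst-case increase in a small box is governed by the directions of largest curvature, the paper notes that the score is closely related to the largest eigenvalues of the Hessian.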


Why Do Flat Minima Generalize Better?

  • Robustness to Perturbations:

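    • The loss at a flat minimum stays low even if the parameters are slightly perturbed, so a small shift between the training and test loss landscapes costs little.

    • At a sharp minimum, small parameter changes can raise the loss steeply, which suggests the solution is overfitted to the training data.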

  • Implicit Regularization:

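    • Small-batch gradients are noisier (each step estimates the gradient from fewer samples), and this noise helps the iterates escape sharp basins.

    • Large-batch gradients are closer to the full gradient, so there is less noise to push the iterates out of sharp basins.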

  • Connection to Bayesian Inference:

    • A flat minimum covers a wide region of parameter space with low loss, so under a Bayesian view it holds more posterior probability mass than a narrow minimum of similar depth.
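
The noise argument can be illustrated on a toy least-squares problem (the synthetic data and the helper names `make_data`, `per_example_grad`, and `minibatch_grad_variance` are assumptions for illustration, not from the paper). This sketch measures how much minibatch gradients vary at a fixed parameter value for several batch sizes. When examples are drawn independently, the variance of the averaged gradient shrinks roughly in proportion to 1/B, so small batches inject far more noise into each step.

```python
import random
import statistics

def make_data(n=10_000, true_w=2.0, noise=1.0, seed=0):
    """Synthetic 1-D regression data: y = true_w * x + Gaussian noise (illustrative only)."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [true_w * x + rng.gauss(0.0, noise) for x in xs]
    return xs, ys

def per_example_grad(w, x, y):
    """Gradient of the squared error 0.5 * (w*x - y)**2 with respect to w."""
    return (w * x - y) * x

def minibatch_grad_variance(w, xs, ys, batch_size, trials=1000, seed=1):
    """Empirical variance of minibatch-averaged gradients (examples sampled with replacement)."""
    rng = random.Random(seed)
    n = len(xs)
    grads = []
    for _ in range(trials):
        idx = [rng.randrange(n) for _ in range(batch_size)]
        grads.append(sum(per_example_grad(w, xs[i], ys[i]) for i in idx) / batch_size)
    return statistics.pvariance(grads)

if __name__ == "__main__":
    xs, ys = make_data()
    w = 0.5  # an arbitrary point away from the optimum
    for b in (8, 32, 128, 512):
        print(b, round(minibatch_grad_variance(w, xs, ys, b), 4))
```

Each fourfold increase in batch size should cut the printed variance by roughly a factor of four. The paper's argument is qualitative; the 1/B scaling here is the standard result for averaging independent samples.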


Experimental Results

  • Tested on fully connected networks (MNIST, TIMIT) and shallow and deep convolutional networks (CIFAR-10, CIFAR-100).

  • Findings:

    • Large batches (sharp minima):

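      • Reach training accuracy similar to small batches, with fewer and more parallelizable iterations, but generalize worse.

      • Higher sharpness scores (more curvature around the minimizer).

    • Small batches (flat minima):

      • Need more sequential iterations but generalize better.

      • Lower sharpness scores (a flatter neighborhood around the minimizer).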


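Remedies Explored for Large-Batch Training

Because large batches are attractive for parallel training, the paper tries several remedies, none of which fully closes the gap:

  1. Warm-starting: train with small batches for some epochs, then switch to large batches (related to gradually increasing the batch size), which reduces the gap when enough small-batch epochs are used.

  2. Data augmentation, conservative (proximal) training, and robust (adversarial) training, which improve test accuracy to varying degrees but do not eliminate the gap.

  3. Sharpness-Aware Minimization (SAM) is a later development (Foret et al., 2021), not part of this paper; it explicitly optimizes for flat regions of the loss.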
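
Since SAM is the best-known later remedy, here is a minimal, framework-free sketch of one SAM update in the spirit of Foret et al. (2021), with plain gradient descent as the base optimizer. Parameters are plain Python lists and the gradient comes from a user-supplied function; `sam_step`, `grad_fn`, and the toy quadratic `quadratic_grad` are hypothetical names for illustration, not an existing API. The step first moves to the approximate worst-case point within an L2 ball of radius `rho`, then applies the gradient computed there to the original weights.

```python
import math

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One Sharpness-Aware Minimization update (first-order approximation).

    w: current parameters as a list of floats.
    grad_fn: function returning the loss gradient (list of floats) at given parameters.
    """
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    # Ascent step: approximate worst-case perturbation within an L2 ball of radius rho
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Descent step: gradient at the perturbed point, applied to the original weights
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

def quadratic_grad(w, curvatures=(1.0, 5.0)):
    """Gradient of the toy loss 0.5 * sum(a_i * w_i**2)."""
    return [a * wi for a, wi in zip(curvatures, w)]

if __name__ == "__main__":
    w = [1.0, 1.0]
    for _ in range(200):
        w = sam_step(w, quadratic_grad)
    print(w)
```

On this toy bowl the parameters approach the minimum, but with a fixed `rho` they keep oscillating by a small amount along the steeper second axis instead of settling exactly at zero.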


Criticisms & Later Work

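  • Sharpness Definitions Matter: Dinh et al. (2017) show that for ReLU networks, rescaling the weights of adjacent layers leaves the network function unchanged but can make an equivalent minimum arbitrarily sharp, so sharpness measures that are not invariant to such reparameterizations cannot explain generalization by themselves.

  • Not All Flat Minima Generalize: flatness alone does not guarantee good generalization, and some flat minima can still overfit.

  • Alternative Views: some work argues that the volume of the low-loss region around a solution (its "local entropy"), not just its curvature, is what matters (Baldassi et al., 2020).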
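
The rescaling argument of Dinh et al. can be reproduced on a toy network. This sketch uses a hypothetical one-unit ReLU model with made-up data, not taken from either paper. It multiplies the first weight by α and divides the second by α. Because relu(αz) = α·relu(z) for α > 0, the network computes exactly the same function, yet the curvature of the loss along w1 scales by 1/α².

```python
def relu(z):
    return max(0.0, z)

def predict(w1, w2, x):
    """Tiny one-hidden-unit ReLU network: f(x) = w2 * relu(w1 * x)."""
    return w2 * relu(w1 * x)

def loss(w1, w2, data):
    """Mean squared error over (x, y) pairs."""
    return sum((predict(w1, w2, x) - y) ** 2 for x, y in data) / len(data)

def curvature_w1(w1, w2, data, h=1e-4):
    """Central finite-difference estimate of the second derivative of the loss w.r.t. w1."""
    return (loss(w1 + h, w2, data) - 2 * loss(w1, w2, data) + loss(w1 - h, w2, data)) / h**2

if __name__ == "__main__":
    data = [(0.5, 1.0), (1.0, 2.0), (2.0, 4.0)]  # y = 2x, fit exactly whenever w1 * w2 = 2
    w1, w2 = 1.0, 2.0
    for alpha in (1.0, 10.0, 0.1):
        a, b = alpha * w1, w2 / alpha  # rescaled weights, same function
        print(alpha, loss(a, b, data), curvature_w1(a, b, data))
```

All three weight pairs fit y = 2x exactly, yet the curvature along w1 comes out near 14, 0.14, and 1400 for α = 1, 10, and 0.1 respectively: the same solution looks flat or sharp depending only on the parameterization.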


Code Example: Measuring Sharpness

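```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def compute_loss(model, data_loader):
    """Mean cross-entropy over (inputs, targets) batches; leaves the model in eval mode."""
    model.eval()
    total, count = 0.0, 0
    for inputs, targets in data_loader:
        total += F.cross_entropy(model(inputs), targets, reduction="sum").item()
        count += targets.numel()
    return total / count

@torch.no_grad()
def compute_sharpness(model, data_loader, epsilon=1e-3, num_samples=10):
    """Estimate sharpness via random perturbations (a cheap proxy, not the paper's L-BFGS-B search)."""
    original_loss = compute_loss(model, data_loader)
    params = list(model.parameters())
    saved = [p.detach().clone() for p in params]  # exact copies for restoring
    max_increase = 0.0
    for _ in range(num_samples):
        # Perturb all parameters jointly, uniformly inside the box epsilon * (|w| + 1)
        for p, w in zip(params, saved):
            p.add_((torch.rand_like(w) * 2 - 1) * epsilon * (w.abs() + 1))
        perturbed_loss = compute_loss(model, data_loader)
        max_increase = max(max_increase, perturbed_loss - original_loss)
        for p, w in zip(params, saved):
            p.copy_(w)  # reset to the original weights
    # Relative increase in percent, normalized by (1 + loss) as in the paper's metric
    return 100.0 * max_increase / (1.0 + original_loss)
```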
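
The random-perturbation approach above can also be tried on hand-made losses without a network. This pure-Python sketch uses hypothetical helpers (`box_sharpness`, `flat_loss`, `sharp_loss`). It samples random points inside the box ε(|w_i| + 1) around a minimizer, keeps the largest loss increase, and normalizes it by 1 + f(w). Random sampling only gives a lower bound on the true maximum; the paper computes the maximum with L-BFGS-B instead.

```python
import random

def box_sharpness(loss_fn, w, epsilon=1e-3, num_samples=1000, seed=0):
    """Random-sampling lower bound on the box-sharpness score around parameters w."""
    rng = random.Random(seed)
    base = loss_fn(w)
    worst = base
    for _ in range(num_samples):
        # Uniform sample inside the box |delta_i| <= epsilon * (|w_i| + 1)
        trial = [wi + rng.uniform(-1.0, 1.0) * epsilon * (abs(wi) + 1.0) for wi in w]
        worst = max(worst, loss_fn(trial))
    return 100.0 * (worst - base) / (1.0 + base)

def flat_loss(w):
    """Low-curvature bowl: 0.5 * sum(w_i**2)."""
    return 0.5 * sum(wi * wi for wi in w)

def sharp_loss(w):
    """High-curvature bowl: 0.5 * 1000 * sum(w_i**2)."""
    return 500.0 * sum(wi * wi for wi in w)

if __name__ == "__main__":
    minimizer = [0.0] * 10
    print("flat :", box_sharpness(flat_loss, minimizer))
    print("sharp:", box_sharpness(sharp_loss, minimizer))
```

With the same random draws, the sharp bowl (curvature 1000) scores exactly 1000 times higher than the flat one (curvature 1), as expected for a pure quadratic.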

Key Takeaways

  1. Small batches → Flat minima → Better generalization.

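  2. Large batches → Sharp minima → More parallelism per step, but worse generalization.

  3. Mitigations: warm-starting with small batches and regularization such as data augmentation (partial fixes explored in the paper), and later sharpness-aware optimizers such as SAM.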

This paper shaped much of the later discussion of how optimization dynamics relate to generalization in deep learning and inspired later work on sharpness-aware training.

Link: https://arxiv.org/pdf/1609.04836
