🐨 KOALA++: Efficient Kalman-Based Optimization with Gradient-Covariance Products

Abstract: We propose KOALA++, a scalable Kalman-based optimization algorithm that explicitly models structured gradient uncertainty in neural network training. Unlike second-order methods, which rely on expensive second-order gradient computations, our method directly estimates the parameter covariance matrix by recursively updating compact gradient-covariance products. This design improves upon the original KOALA framework, which assumed a diagonal covariance, by implicitly capturing richer uncertainty structure without storing the full covariance matrix or inverting large matrices. Across diverse tasks, including image classification and language modeling, KOALA++ achieves accuracy on par with or better than state-of-the-art first- and second-order optimizers while retaining the efficiency of first-order methods.

OpenReview arXiv
¹ Work done during Master studies at the University of Bern · ² Computer Vision Group, University of Bern

Project Overview

⚠️ Note
The overview image was generated by ChatGPT. It is included here purely for aesthetic purposes and has no functional relation to the project itself.

Official implementation of KOALA++, a scalable Kalman-based optimization algorithm for neural network training.
📢 The code will be released after the acceptance of our paper.


📰 Latest News

  • [2025-09] 🎉 Our paper has been accepted at NeurIPS 2025!
  • 🔜 Code release is coming soon — stay tuned!

📁 Repository Contents


💡 Methodology and Key Insights

KOALA++ extends the Kalman filtering view of optimization by explicitly propagating structured gradient uncertainty. Its core innovation lies in tracking a directional covariance surrogate:

$$ v_k := H_k P_{k-1} \in \mathbb{R}^{1\times n}, $$

instead of the full covariance $P_{k-1}\in\mathbb{R}^{n\times n}$. Here $H_k \in \mathbb{R}^{1\times n}$ and $v_k \in \mathbb{R}^{1\times n}$ are row vectors,
and $Q, R \in \mathbb{R}$ are scalars.
This surrogate captures anisotropic uncertainty while keeping memory and computational cost comparable to first-order optimizers.

Expanding the recursion for $v_k = H_k P_{k-1}$ yields a term $H_k P_{k-2}$, which is not directly computable since $P_{k-2}$ is not stored.
KOALA++ resolves this by approximating $P_{k-2}$ via a least-squares problem with the constraint $v_{k-1} = H_{k-1} P_{k-2}$.
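
For intuition, here is one closed-form instance of that approximation, assuming the minimum-Frobenius-norm solution of the constrained least-squares problem is used (the exact formulation and regularization in the paper may differ):

$$ \widehat{P}_{k-2} = \arg\min_{P} \lVert P \rVert_F \quad \text{s.t.} \quad H_{k-1} P = v_{k-1}, \qquad \widehat{P}_{k-2} = \frac{H_{k-1}^{\top} v_{k-1}}{H_{k-1} H_{k-1}^{\top}}, $$

so the otherwise unavailable product reduces to a scalar multiple of the stored row vector:

$$ H_k \widehat{P}_{k-2} \approx \frac{H_k H_{k-1}^{\top}}{\lVert H_{k-1} \rVert^2}\, v_{k-1}, $$

which keeps the recursion for $v_k$ expressed entirely in terms of stored gradient-covariance products.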


Algorithm Summary

Algorithm of KOALA++


🌟 Getting Started

To get started, first clone the AdaFisher benchmark repository and set up the environment as described in their instructions:

git clone https://github.com/AtlasAnalyticsLab/AdaFisher.git
cd AdaFisher
# Follow their README to install the required dependencies

🖼 Task 1: Image Classification

Navigate to the Task1_Image_Classification directory. This task supports training on both CIFAR-10 and CIFAR-100 datasets.

Run Training

  • To train on CIFAR-10:
bash train_cifar10.sh
  • To train on CIFAR-100:
bash train_cifar100.sh

🔗 Integrating KOALA++ as an Optimizer

KOALA++ differs from standard optimizers in that it performs a two-step update:

  • predict() before the forward/backward pass,
  • update(loss_mean, loss_var) after the backward pass.

Here is an example integration into a PyTorch training loop:

import time
import torch

# Initialize the KOALA++ optimizer.
# model, criterion, train_loader, sigma, q, lr, data_time and end are defined
# in the surrounding training script; criterion should return per-sample losses
# (e.g. reduction='none') so the loss statistics below can be computed.
optimizer = KOALAPlusPlus(
    params=model.parameters(),
    sigma=sigma, q=q, r=None, alpha_r=0.9,
    weight_decay=0.0005, lr=lr,
)

for i, (inputs, targets) in enumerate(train_loader):
    # Measure data loading time
    data_time.update(time.time() - end)

    inputs  = inputs.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)

    # --- KOALA++ prediction step (before the forward/backward pass) ---
    optimizer.predict()

    # Forward + compute per-sample losses
    outputs   = model(inputs)
    loss      = criterion(outputs, targets)
    loss_mean = loss.mean()

    # Backward
    optimizer.zero_grad()
    loss_mean.backward()

    # --- KOALA++ update step (after the backward pass) ---
    loss_var = torch.mean(loss.pow(2))   # raw second moment; use the batch variance if your setup expects it
    optimizer.update(loss_mean, loss_var)
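
The loss_var passed above is the raw second moment of the per-sample losses; depending on the implementation, the batch variance may be expected instead. Below is a minimal, self-contained sketch of how those statistics can be computed, assuming a classification criterion configured with reduction='none' (the repository's own scripts may set this up differently):

# Minimal sketch (assumption: the criterion returns per-sample losses so both
# the batch mean and a dispersion statistic can be passed to optimizer.update()).
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='none')   # per-sample losses

outputs = torch.randn(8, 10)                         # dummy logits: batch of 8, 10 classes
targets = torch.randint(0, 10, (8,))                 # dummy labels

loss      = criterion(outputs, targets)              # shape: [8]
loss_mean = loss.mean()                              # scalar mean loss
loss_var  = loss.var(unbiased=False)                 # batch variance, an alternative
                                                     # to the raw second moment above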

Notes

  • You can modify the optimizer, learning rate, and other hyperparameters directly within the respective .sh script files.
  • All optimizers from the AdaFisher benchmark (AdaFisher, SGD, Adam, etc.) are also supported.

🧠 Task 2: Language Modeling

Navigate to the Task2_Language_Model directory.

Run Training

Simply run the corresponding training script to begin training your language model:

bash train_language_model.sh

The script launches training with the configuration defined inside it; modify the script to use different optimizers or hyperparameter settings.


📜 License

This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details.

📖 Citation

If you find this work useful, please cite our paper:

@inproceedings{xiakoala++,
  title={KOALA++: Efficient Kalman-Based Optimization with Gradient-Covariance Products},
  author={Xia, Zixuan and Davtyan, Aram and Favaro, Paolo},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025}
}

✉️ Contact

For questions or collaboration inquiries, please reach out:
Zixuan Xia · [email protected] · [email protected]


⭐️ We appreciate your interest in KOALA++ and look forward to sharing the code and results with the community.
