Abstract: We propose KOALA++, a scalable Kalman-based optimization algorithm that explicitly models structured gradient uncertainty in neural network training. Unlike second-order methods, which rely on expensive second-order gradient computations, our method directly estimates the parameter covariance matrix by recursively updating compact gradient-covariance products. This design improves upon the original KOALA framework, which assumed a diagonal covariance, by implicitly capturing richer uncertainty structure without storing the full covariance matrix or performing large matrix inversions. Across diverse tasks, including image classification and language modeling, KOALA++ achieves accuracy on par with or better than state-of-the-art first- and second-order optimizers while maintaining the efficiency of first-order methods.
The banner image was generated by ChatGPT. It is included only for aesthetic purposes and has no functional relation to the project itself.
Official implementation of KOALA++, a scalable Kalman-based optimization algorithm for neural network training.
📢 The code will be released after the acceptance of our paper.
- [2025-09] 🎉 Our paper has been accepted at NeurIPS 2025!
- 🔜 Code release is coming soon — stay tuned!
- 💡 Methodology and Key Insights
- 🌟 Getting Started
- 🖼 Task 1: Image Classification
- 🧠 Task 2: Language Modeling
- 📜 License
- 📖 Citation
- ✉️ Contact
KOALA++ extends the Kalman filtering view of optimization by explicitly propagating structured gradient uncertainty. Its core innovation lies in tracking a directional covariance surrogate, the gradient-covariance product

$$v_t = P_t\, g_t,$$

instead of the full parameter covariance $P_t$. This surrogate captures anisotropic uncertainty while keeping memory and computational cost comparable to first-order optimizers.

Expanding the Kalman recursion for $v_t$ still introduces terms that depend on the full covariance $P_t$. KOALA++ resolves this by approximating those terms using only the stored gradient-covariance products, so the full matrix is never materialized or inverted.
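As a rough intuition only (not the official KOALA++ recursion), the toy sketch below contrasts the memory footprint of materializing a full covariance matrix `P` against keeping a single covariance-gradient product `v`. The names `P`, `g`, `v` and the sizes are purely illustrative.

```python
# Toy illustration: full covariance vs. a gradient-covariance product.
# This is NOT the KOALA++ algorithm, only a memory-footprint comparison.
import numpy as np

n = 2_000                          # number of parameters (tiny by deep-learning standards)
g = np.random.randn(n)             # a gradient vector

# Full covariance: n x n floats, grows quadratically with the model size.
P = np.eye(n)                      # ~32 MB of float64 already at n = 2,000
v_full = P @ g                     # O(n^2) time per step

# Product-style surrogate: same size as the parameters themselves.
v = g.copy()                       # O(n) memory, updated recursively each step
print(f"{P.nbytes / 1e6:.1f} MB for P vs {v.nbytes / 1e6:.3f} MB for v")
```

Even at this toy scale the full matrix is thousands of times larger than the product vector, which is what makes tracking products (rather than the covariance itself) viable at network scale.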
To get started, first clone the AdaFisher benchmark repository and set up the environment as described in their instructions:
```bash
git clone https://github.com/AtlasAnalyticsLab/AdaFisher.git
cd AdaFisher
# Follow their README to install the required dependencies
```

Navigate to the `Task1_Image_Classification` directory. This task supports training on both the CIFAR-10 and CIFAR-100 datasets.
- To train on CIFAR-10:

```bash
bash train_cifar10.sh
```

- To train on CIFAR-100:

```bash
bash train_cifar100.sh
```

KOALA++ differs from standard optimizers in that it performs a two-step update:

- `predict()` before the forward/backward pass,
- `update(loss_mean, loss_var)` after the backward pass.
Here is an example integration into a PyTorch training loop:
```python
# Initialize KOALA++ optimizer
optimizer = KOALAPlusPlus(
    params=model.parameters(),
    sigma=sigma, q=q, r=None, alpha_r=0.9,
    weight_decay=0.0005, lr=lr
)

for i, (inputs, targets) in enumerate(train_loader):
    # Measure data loading time
    data_time.update(time.time() - end)

    inputs = inputs.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)

    # --- KOALA++ prediction step ---
    optimizer.predict()

    # Forward + compute loss (criterion should return per-sample losses,
    # e.g. reduction='none', so both a mean and a spread are available)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss_mean = loss.mean()

    # Backward
    optimizer.zero_grad()
    loss_mean.backward()

    # --- KOALA++ update step ---
    loss_var = torch.mean(loss.pow(2))  # or variance, depending on implementation
    optimizer.update(loss_mean, loss_var)
```

- You can modify the optimizer, learning rate, and other hyperparameters directly within the respective `.sh` script files.
- All optimizers from AdaFisher (e.g., AdaFisher, SGD, Adam, etc.) are supported.
Navigate to the `Task2_Language_Model` directory.
Simply run the corresponding training script to begin training your language model:

```bash
bash train_language_model.sh
```

The script will use the configuration set inside it to launch the training procedure, and you can modify the script for different optimizers or hyperparameter settings.
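For reference, here is a hypothetical sketch of the same predict/update pattern inside a language-model training step. The model, batch format, and data loader are placeholders, and only the `predict()` / `update(loss_mean, loss_var)` calls mirror the interface shown in Task 1; the actual scripts may differ.

```python
# Hypothetical sketch: wiring KOALA++ into a language-model training step.
# Assumes `model`, `train_loader`, and `optimizer` (a KOALAPlusPlus instance)
# are set up as in the Task 1 example.
import torch
import torch.nn.functional as F

for batch in train_loader:                        # batch: dict with input_ids / labels
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    optimizer.predict()                           # KOALA++ prediction step

    logits = model(input_ids)                     # (batch, seq_len, vocab)
    # Per-token losses so that both a mean and a spread can be passed to update()
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="none",
    )
    loss_mean = loss.mean()

    optimizer.zero_grad()
    loss_mean.backward()

    loss_var = torch.mean(loss.pow(2))            # same second-moment proxy as in Task 1
    optimizer.update(loss_mean, loss_var)         # KOALA++ update step
```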
This project is licensed under the GNU General Public License v3.0 - see the LICENSE file for details.
If you find this work useful, please cite our paper:
```bibtex
@inproceedings{xiakoala++,
  title={KOALA++: Efficient Kalman-Based Optimization with Gradient-Covariance Products},
  author={Xia, Zixuan and Davtyan, Aram and Favaro, Paolo},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025}
}
```
For questions or collaboration inquiries, please reach out:
Zixuan Xia — [email protected] · [email protected]
⭐️ We appreciate your interest in KOALA++ and look forward to sharing the code and results with the community.

