If there is one thing the deep learning revolution has taught us, it's that neural nets will outperform hand-designed heuristics, given enough compute and data.

But we still use hand-designed heuristics to train our models. Let's replace our optimizers with trained neural nets!

If you are training models with < 5e8 parameters, for < 2e5 training steps, then with high probability this LEARNED OPTIMIZER will beat or match the tuned optimizer you are currently using, out of the box, with no hyperparameter tuning (!).

https://velo-code.github.io
https://arxiv.org/abs/2211.09760

Code: https://github.com/google/learned_optimization/tree/main/learned_optimization/research/general_lopt

Meta-training learned optimizers is HARD. Each meta-training datapoint is an entire optimization task, so building a large meta-training dataset is HARD. Each of N meta-training steps can contain N training steps applying the learned optimizer -- so compute is also extreme (N^2).
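As a toy illustration of that quadratic blow-up (the numbers below are made up for illustration, not from the paper):

```python
# Back-of-envelope sketch of why meta-training compute scales like N^2.
# Toy numbers, assumed for illustration only.
meta_steps = 1_000    # outer steps that update the learned optimizer's weights
inner_steps = 1_000   # inner training steps per meta-training task

# Every outer step can require unrolling a full inner training run,
# so total learned-optimizer applications grow multiplicatively.
total_inner_applications = meta_steps * inner_steps
print(total_inner_applications)  # 1_000_000 -- N^2 with N = 1_000
```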

And the resulting learned optimizer works really well! We reached out to other researchers inside Brain, and had them try it on their tasks, and subject to the scale constraints I mention above it did as well or better than what they were currently using, with no tuning.

Tasks include multiple vision models, multiple language models, decision transformers, distillation tasks, scientific modeling, and more.

Huge thanks to Luke Metz for leading this research direction for the last half dozen years (!!) -- it's wonderful to see it bear fruit.

Also, huge thanks to James Harrison who has completely taken over the project over the last several months, and is responsible for the careful analysis and coherent story in the paper, as well as the exciting ongoing work.

(Posting this *exclusive content* to Mastodon before Twitter :P You're winning by being here!)
@jascha Awesome work, thanks for sharing!

And huge thank you also to the other collaborators on this project -- Daniel Freeman, Amil Merchant, @lb @jekbradbury , Naman Agarwal, Ben Poole, Igor Mordatch, and Adam Roberts.

(and if any of you are on Mastodon, but I missed looking up your username -- very sorry, and please reply and claim credit!)

@jascha I can confirm that this is no exaggeration.

I was really impressed by its performance: on my first attempt, with no tuning at all, it matched our heavily tuned Adam/AdaFactor setup.

It only fell behind for ViT-H, which has more than half a billion parameters, is extremely hard to train (see Kaiming's papers for independent confirmation of this), and is waaaaay past OOD.

It also worked great on distillation, which it wasn't even trained on.

@lb @jascha I think it is really impressive. Is there any way to collaborate with this research group?
@jascha Sounds very cool. How big is the overhead of running this vs 'heuristic' optimizers? I.e., is this only a gain when training large models?
@EmilevanKrieken Overhead is relatively small in an absolute sense. It's about 10x the overhead of Adam, which is small compared to the cost of computing the gradient for reasonable-scale problems trained with a reasonable minibatch size. See leftmost pane in this plot:
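A back-of-envelope sketch of why a 10x-Adam optimizer can still be cheap per step (the cost model and all numbers below are my own rough assumptions, not figures from the paper):

```python
# Rough, assumed cost model: gradient FLOPs per step ~ 6 * params * batch_size
# (a common transformer-style estimate), Adam ~ a handful of FLOPs per
# parameter, and the learned optimizer ~ 10x Adam, per the thread.
params = 1e8       # 100M-parameter model (toy assumption)
batch = 256        # reasonable minibatch size (toy assumption)

grad_flops = 6 * params * batch
adam_flops = 10 * params        # assumed ~10 FLOPs per parameter
lopt_flops = 10 * adam_flops    # "about 10x the overhead of Adam"

overhead_fraction = lopt_flops / grad_flops
print(overhead_fraction)  # ~0.065 -- a few percent of the step cost
```

The key point: optimizer cost scales with parameter count alone, while gradient cost also scales with batch size, so the relative overhead shrinks as the batch grows.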
@jascha
That's very impressive, thank you!
@jascha Exciting work. Can't wait to try it
@jascha I've been really curious to try learned optimizers out of the box! Are there any good PyTorch libraries with pretrained optimizers available?
@rbhar90 Nope -- just in JAX for now! [insert evangelizing about how great JAX is here]
@jascha @rbhar90 I'm somewhat keen to roll a PyTorch version. (I've got a straightforward 1D audio task it may be suitable for.) On a quick inspection of the method (haven't looked at the code yet) it seems like it may be a little unpleasant with the overhead for the tiny MLP (might be able to TorchScript it)... But otherwise seems like smooth sailing?
@Sdatkinson @rbhar90 1) That's awesome!! 2) My lack of familiarity with PyTorch is fairly profound. :P From a JAX perspective, I will say that compiling the MLP is important, as is batching the input to the MLP across all parameter scalars in a tensor (same MLP is applied to all scalars in a weight tensor). 3) I would sanity check any port by comparing resulting optimization trajectories for a toy task with multiple parameter tensors. 4) Again, awesome!!
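To make point 2 concrete, here's a minimal sketch of the per-scalar-MLP idea (my own toy version with made-up weights and a made-up two-feature input, not VeLO's actual architecture): one tiny MLP maps per-scalar features to an update, and the same weights are shared across every scalar in a weight tensor, which is why batching the MLP over the whole tensor matters.

```python
import math

# Assumed toy MLP: 2 input features per parameter scalar (e.g. gradient and
# momentum), 2 hidden units, 1 scalar update out. Weights are arbitrary.
W1 = [[0.5, -0.2], [0.1, 0.3]]   # input -> hidden
W2 = [0.7, -0.4]                 # hidden -> update

def mlp_update(features):
    # Same tiny MLP applied to one parameter scalar's feature vector.
    hidden = [math.tanh(sum(w * f for w, f in zip(row, features))) for row in W1]
    return sum(w * h for w, h in zip(W2, hidden))

# "Batched" application: identical MLP weights reused for every scalar in a
# (flattened) parameter tensor -- in JAX this would be a single vmap/compile.
grads = [0.1, -0.3, 0.05]
moms = [0.2, 0.1, -0.1]
updates = [mlp_update([g, m]) for g, m in zip(grads, moms)]
print(len(updates))  # one update per parameter scalar
```

Since the MLP is identical across scalars, comparing a port's full optimization trajectory against the reference on a toy multi-tensor task (point 3 above) catches most transcription bugs.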