If there is one thing the deep learning revolution has taught us, it's that neural nets will outperform hand-designed heuristics, given enough compute and data.

But we still use hand-designed heuristics to train our models. Let's replace our optimizers with trained neural nets!
@jascha I've been really curious to try learned optimizers out of the box! Are there any good PyTorch libraries with pretrained optimizers available?
@rbhar90 Nope -- just in JAX for now! [insert evangelizing about how great JAX is here]
@jascha @rbhar90 I'm somewhat keen to roll a PyTorch version. (I've got a straightforward 1D audio task it may be suitable for.) From a quick look at the method (I haven't looked at the code yet), the overhead of running the tiny MLP might be a little unpleasant (I might be able to TorchScript it)... But otherwise, it seems like smooth sailing?
@Sdatkinson @rbhar90 1) That's awesome!! 2) My lack of familiarity with PyTorch is fairly profound. :P From a JAX perspective, I will say that compiling the MLP is important, as is batching the input to the MLP across all parameter scalars in a tensor (same MLP is applied to all scalars in a weight tensor). 3) I would sanity check any port by comparing resulting optimization trajectories for a toy task with multiple parameter tensors. 4) Again, awesome!!
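Point 2 can be sketched in a few lines of JAX. This is a minimal, hypothetical illustration, not the released learned optimizer: the per-scalar features (parameter value, gradient, momentum), the MLP sizes, the fixed 0.01 step scale, and the names `init_mlp`, `mlp_update` and `learned_opt_step` are all assumptions, and the MLP weights are random rather than meta-trained. What it does show is the structure described above: every scalar of a weight tensor becomes one row of a batch, one shared MLP processes all rows at once, and the whole step is compiled with `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-scalar features: parameter value, gradient, momentum.
NUM_FEATURES = 3
HIDDEN = 4
BETA = 0.9          # assumed momentum decay
STEP_SCALE = 0.01   # assumed scale on the MLP output


def init_mlp(key, hidden=HIDDEN):
    """Random (untrained) weights for the tiny per-parameter MLP."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (NUM_FEATURES, hidden)) * 0.1,
        "b1": jnp.zeros((hidden,)),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.1,
        "b2": jnp.zeros((1,)),
    }


def mlp_update(mlp, param, grad, mom):
    """Apply the same MLP to every scalar of a tensor in one batched call."""
    # Flatten so each scalar is one row: shape (n_scalars, NUM_FEATURES).
    feats = jnp.stack([param.ravel(), grad.ravel(), mom.ravel()], axis=-1)
    h = jnp.tanh(feats @ mlp["w1"] + mlp["b1"])
    out = h @ mlp["w2"] + mlp["b2"]          # shape (n_scalars, 1)
    return out.reshape(param.shape)          # back to the tensor's shape


@jax.jit  # compile the whole step, including the MLP, once per input structure
def learned_opt_step(mlp, params, grads, moms):
    """One step over a pytree with multiple parameter tensors."""
    tmap = jax.tree_util.tree_map
    new_moms = tmap(lambda m, g: BETA * m + (1.0 - BETA) * g, moms, grads)
    updates = tmap(lambda p, g, m: mlp_update(mlp, p, g, m),
                   params, grads, new_moms)
    new_params = tmap(lambda p, u: p - STEP_SCALE * u, params, updates)
    return new_params, new_moms


# Toy problem with two parameter tensors of different shapes.
key = jax.random.PRNGKey(0)
mlp = init_mlp(key)
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}


def loss(p):
    return jnp.sum(p["w"] ** 2) + jnp.sum((p["b"] - 1.0) ** 2)


moms = jax.tree_util.tree_map(jnp.zeros_like, params)
grads = jax.grad(loss)(params)
new_params, new_moms = learned_opt_step(mlp, params, grads, moms)
```

Because the MLP here is untrained, a step is not guaranteed to lower the loss; the point is the batching and compilation pattern. For the sanity check in point 3, one could record `params` after each step on a toy task like this one (which has multiple parameter tensors) and compare those trajectories against the same run in a port.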