I drafted an implementation of Cyclical SGLD using BlackJAX and Optax.
As you can see 👇 Cyclical SGLD, which alternates exploration and sampling phases, does much better than vanilla SGLD on multi-modal targets. Next step: CIFAR-10 with a Bayesian ResNet-18.
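
For anyone curious how the alternation works, here is a minimal plain-Python sketch. It is not the BlackJAX/Optax code itself. It assumes the cosine step-size schedule commonly used for cyclical SG-MCMC, a noise-free gradient step during exploration, and a Langevin step (gradient plus Gaussian noise) during sampling. All function names, the `exploration_ratio` parameter, and the toy bimodal target are hypothetical choices for illustration.

```python
import math
import random


def cyclical_step_size(k, total_steps, num_cycles, initial_step_size):
    """Cosine step size at 0-indexed iteration k; it restarts at every cycle."""
    cycle_length = math.ceil(total_steps / num_cycles)
    position = (k % cycle_length) / cycle_length  # position in the cycle, in [0, 1)
    return 0.5 * initial_step_size * (math.cos(math.pi * position) + 1.0)


def in_sampling_phase(k, total_steps, num_cycles, exploration_ratio):
    """True once the position in the current cycle passes exploration_ratio."""
    cycle_length = math.ceil(total_steps / num_cycles)
    position = (k % cycle_length) / cycle_length
    return position >= exploration_ratio


def cyclical_sgld_step(theta, grad_potential, k, rng, *, total_steps,
                       num_cycles, initial_step_size, exploration_ratio):
    """One update: theta - step * grad U(theta), plus noise when sampling."""
    step = cyclical_step_size(k, total_steps, num_cycles, initial_step_size)
    theta = theta - step * grad_potential(theta)
    if in_sampling_phase(k, total_steps, num_cycles, exploration_ratio):
        # Langevin noise with variance 2 * step (assumed convention).
        theta = theta + math.sqrt(2.0 * step) * rng.gauss(0.0, 1.0)
    return theta


def grad_bimodal_potential(x):
    """Gradient of U(x) = x^2/2 - log cosh(2x), i.e. an equal mixture of
    N(-2, 1) and N(2, 1) up to an additive constant."""
    return x - 2.0 * math.tanh(2.0 * x)


# Toy run: keep only the draws produced during sampling phases.
settings = dict(total_steps=2000, num_cycles=10,
                initial_step_size=0.1, exploration_ratio=0.25)
rng = random.Random(0)
theta, samples = 0.5, []
for k in range(settings["total_steps"]):
    theta = cyclical_sgld_step(theta, grad_bimodal_potential, k, rng, **settings)
    if in_sampling_phase(k, settings["total_steps"], settings["num_cycles"],
                         settings["exploration_ratio"]):
        samples.append(theta)
```

The large step at the start of each cycle is what lets the chain leave its current mode, and the small steps at the end are where samples are collected. In a real setup the gradient would be a minibatch estimate, which this toy example leaves out.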
