Hmm. I switched to #pytorch to implement something quick. It turns out it's not possible to functorch.vmap() a function that is built on torch.autograd.Function. Unlike in #jax, this makes it hard to vectorize straight-through estimators.
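
For comparison, here is a minimal sketch of the #jax side: a rounding straight-through estimator written with jax.custom_vjp (roughly the counterpart of torch.autograd.Function) and then vectorized with jax.vmap. The names round_ste, loss, and batch are made up for illustration and are not from any particular codebase.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def round_ste(x):
    # Forward pass: ordinary (non-differentiable) rounding.
    return jnp.round(x)


def _round_ste_fwd(x):
    # No residuals are needed for the backward pass.
    return jnp.round(x), None


def _round_ste_bwd(_residuals, g):
    # Straight-through: pass the incoming gradient along unchanged.
    return (g,)


round_ste.defvjp(_round_ste_fwd, _round_ste_bwd)


def loss(x):
    # Hypothetical per-example loss that uses the estimator.
    return jnp.sum(round_ste(x) ** 2)


# Hypothetical batch of two examples, two features each.
batch = jnp.array([[0.2, 1.7], [2.4, -0.6]])

# Per-example gradients: vmap over the leading axis of a grad-transformed function.
per_example_grads = jax.vmap(jax.grad(loss))(batch)
```

With the identity backward rule, each gradient comes out as 2 * round(x), so per_example_grads has the same shape as batch.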