porting some code that samples audio triplets, crossfades them & generates spectrograms from a mix of numpy/tensorflow to jax.
being able to move the batching all the way out with vmap means sooo much less code.. it'll be even cleaner once it's bolted onto the training loop
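To illustrate what "moving the batching all the way out" looks like, here is a minimal sketch, not the code from this thread. It assumes clips are equal-length 1-D float arrays, uses a simple linear crossfade, and computes a Hann-windowed magnitude spectrogram by hand. `FRAME_LEN`, `HOP`, `crossfade`, `spectrogram` and `crossfade_spectrogram` are hypothetical names. Each function handles ONE example; `jax.vmap` adds the batch axis in a single place.

```python
import jax
import jax.numpy as jnp

FRAME_LEN = 256  # hypothetical STFT frame length
HOP = 128        # hypothetical hop size between frames

def crossfade(a, b):
    """Linearly crossfade from clip `a` to clip `b`.

    a, b: (num_samples,) arrays for ONE example (no batch axis).
    Returns: (num_samples,)
    """
    w = jnp.linspace(0.0, 1.0, a.shape[0])  # fade-in weights for `b`
    return (1.0 - w) * a + w * b

def spectrogram(x):
    """Magnitude spectrogram of ONE clip.

    x: (num_samples,)
    Returns: (num_frames, FRAME_LEN // 2 + 1)
    """
    num_frames = 1 + (x.shape[0] - FRAME_LEN) // HOP
    # (num_frames, FRAME_LEN) matrix of sample indices, one row per frame.
    idx = jnp.arange(FRAME_LEN)[None, :] + HOP * jnp.arange(num_frames)[:, None]
    frames = x[idx] * jnp.hanning(FRAME_LEN)
    return jnp.abs(jnp.fft.rfft(frames, axis=-1))

def crossfade_spectrogram(a, b):
    """Compose the unbatched pieces for ONE pair of clips."""
    return spectrogram(crossfade(a, b))

# The only place a batch axis appears: map over axis 0 of both inputs.
batched = jax.jit(jax.vmap(crossfade_spectrogram))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (8, 16000))  # batch of 8 fake 16000-sample clips
b = jax.random.normal(key_b, (8, 16000))
specs = batched(a, b)
print(specs.shape)  # (8, 124, 129)
```

None of the per-example functions mention a batch size, so the same `crossfade_spectrogram` can be reused unbatched, vmapped, or nested inside a larger vmapped training step.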
@mat_kelcey out of curiosity, how do you write the docstring of a function meant to be vmap'ed, in particular wrt the expected input shape and the meaning of the input axes? I'd be inclined to include the batch dimension to make it user friendly, but that might confuse people reading the source code of the function.
@ogrisel good question, not sure. jax is all about function composition, so i wonder if there are patterns from languages like haskell to follow? i commonly find i write a collection of non-batched pieces that "get assembled" in a short bit of code with multiple vmaps, pmaps etc., so at least the more complex composition is in one place
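One possible convention that follows from this reply, sketched under assumptions rather than taken from either author: document per-example shapes in the non-batched pieces, and document batched shapes only at the short assembly site where the vmaps live. The toy embedding model (mean-pool plus a linear layer), the triplet hinge loss, and the names `embed`, `triplet_loss`, `sq_dist`, `batch_loss` and `pairwise_sq_dist` are all hypothetical.

```python
import jax
import jax.numpy as jnp

def embed(params, spec):
    """Embed ONE spectrogram (hypothetical toy model: mean-pool then linear).

    params: dict with 'w' of shape (num_bins, dim) and 'b' of shape (dim,).
    spec:   (num_frames, num_bins), a single example with no batch axis.
    Returns: (dim,) L2-normalised embedding.
    """
    h = spec.mean(axis=0) @ params["w"] + params["b"]
    return h / (jnp.linalg.norm(h) + 1e-8)

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge triplet loss for ONE (anchor, positive, negative) triple of (dim,) vectors."""
    d_pos = jnp.sum((anchor - positive) ** 2)
    d_neg = jnp.sum((anchor - negative) ** 2)
    return jnp.maximum(0.0, d_pos - d_neg + margin)

def sq_dist(x, y):
    """Squared Euclidean distance between two (dim,) vectors."""
    return jnp.sum((x - y) ** 2)

# Assembly: the only code that knows about batch axes (B = batch size).

def batch_loss(params, anchors, positives, negatives):
    """anchors, positives, negatives: (B, num_frames, num_bins). Returns scalar mean loss."""
    embed_batch = jax.vmap(embed, in_axes=(None, 0))  # share params, map over examples
    ea, ep, en = (embed_batch(params, x) for x in (anchors, positives, negatives))
    return jnp.mean(jax.vmap(triplet_loss)(ea, ep, en))

def pairwise_sq_dist(xs, ys):
    """xs: (N, dim), ys: (M, dim). Returns (N, M) via two nested vmaps."""
    return jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))(xs, ys)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {"w": jax.random.normal(k1, (129, 16)), "b": jnp.zeros(16)}
specs = jax.random.uniform(k2, (3, 4, 124, 129))  # (anchor/pos/neg, B=4, frames, bins)
loss = batch_loss(params, specs[0], specs[1], specs[2])
grads = jax.grad(batch_loss)(params, specs[0], specs[1], specs[2])
e = jax.random.normal(k3, (5, 16))
d = pairwise_sq_dist(e[:2], e)  # (2, 5)
```

Reading any single function only requires its own per-example shapes; the batched shapes appear once, next to the `jax.vmap` calls that create them.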