porting some code that samples audio triplets, crossfades them & generates spectrograms from a mix of numpy/tensorflow to jax.
being able to move the batching all the way out with vmap means sooo much less code... it'll be even cleaner once it's bolted onto the training loop
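A minimal sketch of what "moving the batching all the way out" can look like, assuming each example is a pair of 1-D clips that get randomly cropped, linearly crossfaded and turned into a magnitude spectrogram. The thread doesn't spell out the triplet structure, so that part is left out, and all names and sizes here (`CROP_LEN`, `mix_example`, etc.) are hypothetical. Every function is written for a single example; `jax.vmap` adds the batch axis at the very end.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
CROP_LEN = 4096   # samples taken from each source clip
FADE_LEN = 512    # length of the linear crossfade region
FRAME_LEN = 256   # spectrogram frame length
HOP = 128         # spectrogram hop size


def sample_crop(key, clip):
    # Pick a random window of CROP_LEN samples from one 1-D clip.
    start = jax.random.randint(key, (), 0, clip.shape[0] - CROP_LEN + 1)
    return jax.lax.dynamic_slice(clip, (start,), (CROP_LEN,))


def crossfade(a, b):
    # Linearly blend the last FADE_LEN samples of `a` into the first FADE_LEN of `b`.
    ramp = jnp.linspace(0.0, 1.0, FADE_LEN)
    blended = a[-FADE_LEN:] * (1.0 - ramp) + b[:FADE_LEN] * ramp
    return jnp.concatenate([a[:-FADE_LEN], blended, b[FADE_LEN:]])


def spectrogram(x):
    # Magnitude spectrogram: Hann-windowed frames -> real FFT -> abs.
    n_frames = (x.shape[0] - FRAME_LEN) // HOP + 1
    idx = jnp.arange(FRAME_LEN)[None, :] + HOP * jnp.arange(n_frames)[:, None]
    frames = x[idx] * jnp.hanning(FRAME_LEN)
    return jnp.abs(jnp.fft.rfft(frames, axis=-1))


def mix_example(key, clip_a, clip_b):
    # Written for ONE example: no batch axis anywhere in here.
    key_a, key_b = jax.random.split(key)
    a = sample_crop(key_a, clip_a)
    b = sample_crop(key_b, clip_b)
    return spectrogram(crossfade(a, b))


# vmap pushes batching to the outside: every argument gains a leading axis.
mix_batch = jax.jit(jax.vmap(mix_example))

# Usage with dummy data: 8 examples, each clip 16000 samples long.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
clips_a = jnp.ones((8, 16000))
clips_b = jnp.zeros((8, 16000))
specs = mix_batch(keys, clips_a, clips_b)  # (batch, frames, freq_bins)
```

The per-example functions never mention a batch dimension, so the same code can be vmapped again (or dropped into a training step) without reshaping logic.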
@mat_kelcey out of curiosity, how do you write the docstring of a function meant to be vmap'ed, in particular wrt the expected input shapes and the meaning of the input axes? I would be tempted to include the batch dimension to make it user friendly, but that might be confusing for people reading the function's source code.
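One possible convention (a suggestion, not something the thread settles on): document the per-example shapes on the function whose source people will read, and give the batched shapes their own docstring on a thin wrapper that applies `jax.vmap`. Both function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def sq_dist(x, y):
    """Squared Euclidean distance between two vectors.

    Written for a single example; use ``batched_sq_dist`` (or wrap it in
    ``jax.vmap`` yourself) for batches.

    Args:
      x: array of shape (d,).
      y: array of shape (d,).

    Returns:
      Scalar array of shape ().
    """
    return jnp.sum((x - y) ** 2)


def batched_sq_dist(xs, ys):
    """Batched version of ``sq_dist``, mapped over axis 0 of both inputs.

    Args:
      xs: array of shape (batch, d).
      ys: array of shape (batch, d).

    Returns:
      Array of shape (batch,).
    """
    return jax.vmap(sq_dist)(xs, ys)
```

This keeps the shapes in the core function's docstring honest to its source, while users who only ever call the batched entry point still see the leading batch axis spelled out.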
@ogrisel i've recently done some work porting some sklearn inference code to jax, which makes it super easy to combine with keras and export to TF Lite for running on microcontrollers. not quite public yet, but when it is, i'll ping you :) sklearn + jax is an awesome combo..
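As a rough illustration of what porting sklearn inference to jax can mean (not the code mentioned above, which isn't public), here is a sketch for a binary `sklearn.linear_model.LogisticRegression` fitted with default settings, where `predict_proba` is the logistic sigmoid of the decision function. It assumes the fitted `coef_` (shape `(1, n_features)`) and `intercept_` (shape `(1,)`) are copied over; `make_predict_proba` is a hypothetical helper, and the keras / TF Lite export step is not shown.

```python
import jax
import jax.numpy as jnp


def make_predict_proba(coef, intercept):
    # coef: (1, n_features), intercept: (1,), as stored by a fitted *binary*
    # sklearn LogisticRegression in `coef_` and `intercept_`.
    w = jnp.asarray(coef)
    b = jnp.asarray(intercept)

    def predict_proba(x):
        # x: (n_features,) for ONE example -> [P(class 0), P(class 1)].
        p1 = jax.nn.sigmoid(x @ w[0] + b[0])
        return jnp.stack([1.0 - p1, p1])

    return predict_proba


# Hypothetical weights standing in for clf.coef_ / clf.intercept_.
predict = jax.jit(jax.vmap(make_predict_proba([[1.0, -1.0]], [0.0])))
probs = predict(jnp.array([[0.0, 0.0], [1.0, -1.0]]))  # shape (2, 2)
```

With real weights, the output should agree with `clf.predict_proba(X)` up to floating-point error, and because the result is a plain jax function it can be composed with other jax code like any other.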

@mat_kelcey this is very interesting, thanks for the heads up. We might even have a story to (optionally) use jax at training time for some specific estimators:

- either via the Array API spec that JAX might want to target at some point: https://scikit-learn.org/dev/modules/array_api.html#array-api

- or for Cython powered estimators, via a new plugin system: https://github.com/scikit-learn/scikit-learn/pull/24826

@mat_kelcey for those not familiar with the Array API specification, here is a good intro on the official website:

https://data-apis.org/array-api/latest/purpose_and_scope.html