This is now well within reach thanks to the Array API.

- [x] In every method, obtain the namespace with `xp = get_namespace(*arrays)` and then use `xp` instead of `jax.numpy`.
- [ ] Use the special functions extension (`xp.special`). Depends on https://github.com/data-apis/array-api/issues/725 (RFC: special function extension).
- [ ] Support sampling methods for:
  - [x] JAX
  - [ ] PyTorch
  - [ ] NumPy
- [ ] Support native fixed-point sampling methods (used in exp-to-nat) for:
  - [x] JAX
  - [ ] PyTorch
  - [ ] NumPy
- [ ] Generalize `abstract_custom_jvp` to PyTorch.
- [ ] Port the Fisher information code (which depends on automatic differentiation) to PyTorch.
- [ ] Move automatic JIT application from methods to tests.
- [ ] Make tests work for each namespace:
  - [x] JAX
  - [ ] PyTorch
  - [ ] NumPy
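For illustration, here is a minimal sketch of the `xp = get_namespace(*arrays)` pattern from the first item. The `get_namespace` helper below is hypothetical: it relies only on the Array API standard's `__array_namespace__` method and falls back to NumPy for plain `ndarray`s on NumPy releases that predate that method. In practice a package such as array-api-compat (its `array_namespace` function) covers more edge cases. The function `normal_log_normalizer` is also hypothetical: it computes the log-normalizer of a univariate normal in natural parameters, $A(\eta) = -\eta_1^2 / (4\eta_2) - \tfrac{1}{2}\log(-2\eta_2)$, and calls only `xp` functions, never `jax.numpy`.

```python
import numpy as np


def get_namespace(*arrays):
    """Hypothetical helper: return the single array namespace shared by all inputs.

    Uses the Array API ``__array_namespace__`` protocol; plain NumPy arrays on
    NumPy < 2.0 lack that method, so they fall back to the ``numpy`` module.
    """
    namespaces = set()
    for x in arrays:
        if hasattr(x, "__array_namespace__"):
            namespaces.add(x.__array_namespace__())
        else:
            namespaces.add(np)
    if len(namespaces) != 1:
        raise TypeError(f"Expected arrays from exactly one namespace, got {namespaces}")
    return namespaces.pop()


def normal_log_normalizer(eta1, eta2):
    """Hypothetical example: log-normalizer of a univariate normal.

    Natural parameters: eta1 = mu / sigma**2, eta2 = -1 / (2 * sigma**2) (so eta2 < 0).
    The base measure 1 / sqrt(2 * pi) is excluded from A(eta).
    """
    xp = get_namespace(eta1, eta2)  # NumPy, JAX, PyTorch, ... depending on the inputs
    return -(eta1**2) / (4 * eta2) - 0.5 * xp.log(-2 * eta2)
```

With NumPy inputs, `normal_log_normalizer(np.asarray([0.0, 2.0]), np.asarray([-0.5, -0.5]))` evaluates $\mu^2 / (2\sigma^2) + \log\sigma$ for $\mu \in \{0, 2\}$ and $\sigma = 1$. The same function body would run unchanged on any other array library whose arrays implement `__array_namespace__`.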