Support NumPy and PyTorch #26

@NeilGirdhar

Description

This is now well within reach thanks to the Array API standard.

  • For every method, find xp = get_namespace(*arrays) and then use xp instead of jax.numpy.
  • Use the special-functions extension (xp.special) for special functions. This depends on RFC: special function extension data-apis/array-api#725.
  • Support sampling methods for:
    • Jax
    • PyTorch
    • NumPy
  • Support native fixed point sampling methods (used in exp-to-nat) for:
    • Jax
    • PyTorch
    • NumPy
  • Generalize abstract_custom_jvp to PyTorch.
  • Port the Fisher information code (which depends on automatic differentiation) to PyTorch.
  • Move automatic JIT application out of the methods and into the tests.
  • Make tests work for each namespace:
    • Jax
    • PyTorch
    • NumPy
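A minimal sketch of the get_namespace pattern from the first item, assuming only the Array API's standard __array_namespace__() method on each array. The helper get_namespace, the ToyArray class, make_toy_namespace, and bernoulli_log_normalizer are hypothetical stand-ins chosen so the sketch runs with the standard library alone; in practice, libraries such as array-api-compat provide an equivalent helper (array_namespace) for NumPy, PyTorch, and JAX arrays. The point is that one method body, written against xp, serves every namespace.

```python
import math
from types import SimpleNamespace


class ToyArray:
    """Hypothetical stand-in for a real array type (ndarray, Tensor, jax.Array)."""

    def __init__(self, values, namespace):
        self.values = list(values)
        self._namespace = namespace

    def __array_namespace__(self, *, api_version=None):
        # Standard Array API hook: return the namespace that owns this array.
        return self._namespace


def make_toy_namespace(name):
    """Build a hypothetical namespace exposing a tiny subset of the Array API."""
    ns = SimpleNamespace(name=name)

    def elementwise(fn):
        # Apply a scalar function to every element, staying in this namespace.
        return lambda x: ToyArray(map(fn, x.values), ns)

    ns.asarray = lambda values: ToyArray(values, ns)
    ns.exp = elementwise(math.exp)
    ns.log1p = elementwise(math.log1p)
    return ns


def get_namespace(*arrays):
    """Return the single namespace shared by all arrays (hypothetical helper)."""
    # Deduplicate by identity: namespace objects need not be hashable.
    namespaces = {id(ns): ns for ns in (a.__array_namespace__() for a in arrays)}
    if len(namespaces) != 1:
        raise TypeError(
            f"expected arrays from exactly one namespace, got {len(namespaces)}")
    return next(iter(namespaces.values()))


def bernoulli_log_normalizer(eta):
    """Log-normalizer log(1 + exp(eta)), written against xp, not jax.numpy."""
    xp = get_namespace(eta)
    return xp.log1p(xp.exp(eta))


if __name__ == "__main__":
    # The same function body runs unchanged for each namespace.
    for xp in (make_toy_namespace("toy_a"), make_toy_namespace("toy_b")):
        print(xp.name, bernoulli_log_normalizer(xp.asarray([0.0, 1.0])).values)
```

Mixing arrays from different namespaces raises TypeError, which mirrors the single-namespace assumption behind get_namespace(*arrays). Because each method depends only on xp, a test suite can loop or parametrize over the Jax, PyTorch, and NumPy namespaces with the same assertions.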

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests