Diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.

Installation

pip install diffrax

Requires Python >=3.7 and JAX >=0.2.27.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

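from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

# Vector field of the ODE dy/dt = -y (args is unused here).
def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
# Integrate from t0=0 to t1=1 with an initial step size of 0.1.
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)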

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)
