sympy2jax

Turn SymPy expressions into trainable JAX expressions. The output will be an Equinox module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.

Optimise your symbolic expressions via gradient descent!

Installation

pip install sympy2jax

Requires:
Python 3.7+
JAX 0.3.4+
Equinox 0.5.3+
SymPy 1.7.1+

Example

import jax
import sympy
import sympy2jax

x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
mod = sympy2jax.SymbolicModule([cosx, sinx])  # PyTree of input expressions

x = jax.numpy.zeros(3)
out = mod(x_sym=x)  # PyTree of results.
params = jax.tree_util.tree_leaves(mod)  # 1.0 and 2.0 are parameters.
                                         # (Which may be trained in the usual way for Equinox.)

Documentation

sympy2jax.SymbolicModule(expressions)

Where expressions is a PyTree of SymPy expressions.

Instances can be called with key-value pairs of symbol-value, as in the above example.

Instances have a .sympy() method that translates the module back into a PyTree of SymPy expressions.

(That's literally the entire documentation; it's super easy.)

Finally

See also: other tools in the JAX ecosystem

Neural networks: Equinox.

Numerical differential equation solvers: Diffrax.

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: jaxtyping.

Disclaimer

This is not an official Google product.
