.. jaxhps documentation master file, created by
   sphinx-quickstart on Wed Mar 26 11:43:02 2025.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

jaxhps documentation
====================
The ``jaxhps`` package provides utilities for constructing fast, direct solvers for systems of linear elliptic partial differential equations. It uses `jax `_ for hardware-accelerated linear algebra operations.

Please see our preprint `Hardware Acceleration for HPS Algorithms in Two and Three Dimensions `_ for details about the algorithms implemented in this package.

If you find this work useful, please cite our paper::

    @misc{melia2025,
        title={Hardware Acceleration for {HPS} Algorithms in Two and Three Dimensions},
        author={Owen Melia and Daniel Fortunato and Jeremy Hoskins and Rebecca Willett},
        year={2025},
        eprint={2503.17535},
        archivePrefix={arXiv},
        primaryClass={math.NA},
        url={https://arxiv.org/abs/2503.17535},
    }

Source Repository
-------------------
Available on GitHub at ``_.
Installation
----------------
The ``jaxhps`` package requires ``scipy>=1.14`` and ``jax>=0.4``. You can use ``pip`` to install ``jaxhps`` and its dependencies directly from PyPI:

.. code:: bash

   pip install jaxhps

However, if jax is not already installed, this will install a CPU-only version of jax. If you want to install jax with GPU support, the suggested installation command is:

.. code:: bash

   pip install jax[cuda12]
   pip install jaxhps

Here, ``cuda12`` should be replaced with the CUDA version appropriate for your system. See the `jax installation guide `_ for more details on installing JAX with GPU support.
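
After installing, it can be useful to confirm which devices JAX actually sees. The following is a minimal sketch that uses only the public ``jax.devices`` and ``jax.default_backend`` functions; the reported names depend on your hardware and installation, so no particular output is assumed here.

.. code:: python

   import jax

   # List every device JAX can use. With a working GPU installation,
   # this list should contain GPU devices rather than only CPU devices.
   devices = jax.devices()
   print(devices)

   # Name of the backend JAX uses by default for new computations.
   backend = jax.default_backend()
   print(backend)

If only CPU devices are listed after a GPU installation, the CUDA-enabled build of jax was most likely not picked up.
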
The examples require additional packages ``matplotlib>=3.8.4`` and ``h5py>=3.11.0``. If you want to install them automatically, use:

.. code:: bash

   pip install jaxhps[examples]

Usage quickstart
-------------------
You can use the ``jaxhps`` package to solve systems of linear elliptic PDEs by first specifying the root of the domain and then specifying the parameters for the high-order composite spectral collocation scheme:

.. code:: python

   import jaxhps

   root = jaxhps.DiscretizationNode2D(xmin=0.0, xmax=1.0, ymin=0.0, ymax=1.0)
   domain = jaxhps.Domain(
       p=16,       # polynomial degree of leaf Chebyshev points
       q=14,       # polynomial degree of boundary Gauss-Legendre points
       root=root,  # root of the domain tree
       L=3,        # number of levels in the domain tree
   )

The :class:`jaxhps.Domain` object will construct the discretization tree and all of the discretization points. There are utilities provided for high-order polynomial interpolation to and from the discretization points. This example constructs a uniform 2D quadtree with ``L=3`` levels, but the code also supports octrees for 3D problems and non-uniform (adaptive) trees in both 2D and 3D.
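
To get a feel for the size of this discretization, the sketch below counts leaves and points for the parameters used above and builds a tensor-product Chebyshev grid on a single leaf. It is an illustration only, written with NumPy rather than ``jaxhps``. It assumes that each level splits every node into four children and that each leaf carries ``p`` Chebyshev points of the second kind per dimension. The helper names ``cheb_points`` and ``leaf_grid`` are hypothetical, and the point ordering used inside ``jaxhps`` may differ.

.. code:: python

   import numpy as np

   p = 16  # Chebyshev parameter from the example above
   L = 3   # number of levels from the example above

   # Assumption: a uniform quadtree splits every node into four children.
   n_leaves = 4 ** L
   # Assumption: each leaf holds a p x p tensor-product grid.
   n_interior_points = n_leaves * p ** 2


   def cheb_points(n):
       """Chebyshev points of the second kind on [-1, 1], in increasing order."""
       return -np.cos(np.pi * np.arange(n) / (n - 1))


   def leaf_grid(xmin, xmax, ymin, ymax, n):
       """Tensor-product Chebyshev grid on one leaf, with shape (n * n, 2)."""
       ref = cheb_points(n)
       # Affine map from the reference interval [-1, 1] to each leaf side.
       x = 0.5 * (xmax - xmin) * (ref + 1.0) + xmin
       y = 0.5 * (ymax - ymin) * (ref + 1.0) + ymin
       xx, yy = np.meshgrid(x, y, indexing="ij")
       return np.stack([xx.ravel(), yy.ravel()], axis=-1)


   # On the unit square, one leaf of the uniform tree has side length 1 / 2**L.
   h = 1.0 / 2 ** L
   grid = leaf_grid(0.0, h, 0.0, h, p)
   print(n_leaves, n_interior_points, grid.shape)

Under these assumptions the example above has 64 leaves and 16384 leaf points in total, which is why the coefficient arrays later in this quickstart are built directly from the points stored on the ``Domain``.
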
You can then define a :class:`jaxhps.PDEProblem` to specify a differential operator and source term. Suppose we want to solve this problem:

.. math::

   \Delta u(x) &= 0 \quad \text{in } \Omega \\
   u(x) &= x_1^2 - x_2^2 \quad \text{on } \partial\Omega

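
The boundary data here is the restriction of a harmonic function: :math:`\Delta (x_1^2 - x_2^2) = 2 - 2 = 0`. As a quick sanity check that is independent of ``jaxhps``, the sketch below evaluates the Laplacian with JAX automatic differentiation at a few sample points. The helper names ``u`` and ``laplacian`` and the sample points are chosen for illustration only.

.. code:: python

   import jax
   import jax.numpy as jnp


   def u(x):
       # The boundary data from the problem above, viewed as a function on the plane.
       return x[0] ** 2 - x[1] ** 2


   def laplacian(f, x):
       # The Laplacian is the trace of the Hessian.
       return jnp.trace(jax.hessian(f)(x))


   sample_points = jnp.array([[0.0, 0.0], [0.3, 0.7], [1.0, 0.25]])
   # Evaluate the Laplacian at every sample point.
   laplacian_values = jax.vmap(lambda x: laplacian(u, x))(sample_points)
   print(laplacian_values)

All entries should vanish, so this function can serve as a reference solution for the whole domain.
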
We can define an instance of ``PDEProblem`` to represent this problem as follows:

.. code:: python

   import jax.numpy as jnp
   import jaxhps

   # It's helpful to use the Domain's quadrature points
   source_term = jnp.zeros_like(domain.interior_points[..., 0])
   D_xx_coeffs = jnp.ones_like(domain.interior_points[..., 0])
   D_yy_coeffs = jnp.ones_like(domain.interior_points[..., 0])

   # Create the PDEProblem instance
   pde_problem = jaxhps.PDEProblem(
       domain=domain,  # the domain we constructed above
       source=source_term,
       D_xx_coefficients=D_xx_coeffs,
       D_yy_coefficients=D_yy_coeffs,
   )

This ``PDEProblem`` instance now represents the differential operator and source term for our problem. The coefficients for the differential operator can be constant or can vary spatially, as long as they are defined on the interior points of the domain. Now that the ``PDEProblem`` is defined, we can build a direct solver for it using :func:`jaxhps.build_solver`.

.. code:: python

   # Doesn't return anything. Stores solution operators inside the pde_problem instance
   jaxhps.build_solver(pde_problem=pde_problem)

Now that the solver has been built, we can apply boundary data to get the solution using :func:`jaxhps.solve`.

.. code:: python

   # Define the boundary data
   boundary_data = domain.boundary_points[..., 0]**2 - domain.boundary_points[..., 1]**2

   # Apply the boundary data to the solver
   solution = jaxhps.solve(
       pde_problem=pde_problem,
       boundary_data=boundary_data,
   )

The :func:`jaxhps.solve` function will return the solution on the HPS grid points, which are ordered in a particular way to make the computation easier. To visualize the solution, it's easiest to use the :class:`jaxhps.Domain`'s interpolation utilities to interpolate the solution to a regular grid:

.. code:: python

   # Interpolate the solution onto a regular grid for plotting.
   n_pixels = 100
   x_pts = jnp.linspace(root.xmin, root.xmax, n_pixels)
   y_pts = jnp.linspace(root.ymin, root.ymax, n_pixels)
   solution_pixels, pixel_locations = domain.interp_from_interior_points(
       solution, x_pts, y_pts
   )

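
Interpolating from the HPS points to a regular grid amounts to evaluating a high-order polynomial interpolant. As a rough one-dimensional illustration, and not the ``jaxhps`` implementation, the sketch below uses barycentric interpolation from Chebyshev points of the second kind to a uniform grid. It assumes only NumPy, and the helper names ``cheb_points`` and ``barycentric_interp`` are hypothetical.

.. code:: python

   import numpy as np


   def cheb_points(n):
       """Chebyshev points of the second kind on [-1, 1], in increasing order."""
       return -np.cos(np.pi * np.arange(n) / (n - 1))


   def barycentric_interp(nodes, values, x_eval):
       """Evaluate the polynomial interpolant of (nodes, values) at x_eval."""
       n = nodes.shape[0]
       # Barycentric weights for second-kind Chebyshev points:
       # alternating signs, halved at the two endpoints.
       weights = (-1.0) ** np.arange(n)
       weights[0] *= 0.5
       weights[-1] *= 0.5
       diff = x_eval[:, None] - nodes[None, :]
       on_node = diff == 0.0
       diff[on_node] = 1.0  # placeholder; these rows are overwritten below
       terms = weights / diff
       result = (terms @ values) / terms.sum(axis=1)
       # Where an evaluation point coincides with a node, use the nodal value.
       rows, cols = np.nonzero(on_node)
       result[rows] = values[cols]
       return result


   p = 16
   nodes = cheb_points(p)
   values = nodes ** 2 - 0.25  # samples of a smooth test function
   x_uniform = np.linspace(-1.0, 1.0, 100)
   interpolated = barycentric_interp(nodes, values, x_uniform)
   max_error = np.max(np.abs(interpolated - (x_uniform ** 2 - 0.25)))
   print(max_error)

Because the test function is a low-degree polynomial, the interpolant reproduces it up to rounding error. This is the same reason the quickstart solution remains accurate after it is interpolated to the plotting grid.
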
Now we can use ``matplotlib`` to visualize the solution. We know the analytical solution of this problem is :math:`u(x) = x_1^2 - x_2^2`, so we can compare the numerical solution to the analytical one:

.. code:: python

   import matplotlib.pyplot as plt

   # Expected solution for comparison
   expected_solution = pixel_locations[..., 0] ** 2 - pixel_locations[..., 1] ** 2

   # Plot the computed solution and the deviations from the expected solution.
   fig, ax = plt.subplots(1, 2, figsize=(12, 6))
   im_0 = ax[0].imshow(
       solution_pixels,
       extent=(root.xmin, root.xmax, root.ymin, root.ymax),
       origin="lower",
   )
   plt.colorbar(im_0, ax=ax[0])
   ax[0].set_title("Computed Solution")

   im_1 = ax[1].imshow(
       jnp.abs(solution_pixels - expected_solution),
       extent=(root.xmin, root.xmax, root.ymin, root.ymax),
       origin="lower",
       cmap="hot",
   )
   plt.colorbar(im_1, ax=ax[1])
   ax[1].set_title("Errors")

   plt.tight_layout()
   plt.show()

This should show the following figure. Note that even after interpolation, the solution is within :math:`3 \times 10^{-14}` of the expected solution.

.. image:: images/usage_quickstart.svg
   :align: center
   :width: 600
   :alt: Showing the solution of the quickstart problem, and the deviations from the expected solution.

In the ``jaxhps`` package, there are many more utilities for working with HPS algorithms, including adaptive discretization methods, computing on GPUs, and interpolation to and from the HPS discretization.

.. toctree::
   :maxdepth: 1
   :caption: Contents:

   DiscretizationNode
   PDEProblem
   solution_methods
   method_API
   quadrature
   Examples
   Device_and_data
   Contributing