Examples
In the source repository, we include code for a few example uses of the HPS routines. The code for these examples is not distributed in the jaxhps package, but it is available in the examples directory of the source repository.
Note
These scripts are optimized for GPUs with 80GB of VRAM. If you have a GPU with less VRAM, you may need to reduce the problem size or polynomial degree to avoid out-of-memory errors. These out-of-memory errors look something like this:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate ...
hp convergence on 2D problems with known solutions
This example shows convergence on uniform quadtrees using both DtN matrices and ItI matrices. It uses the solution methods jaxhps.build_solver() and jaxhps.solve() to solve the PDE problem.
python examples/hp_convergence_2D_problems.py --DtN --ItI
The example problems being solved are a Dirichlet problem with variable coefficients (DtN case) and a variable-coefficient Helmholtz problem with a Robin boundary condition (ItI case). Plotting the results should show convergence of the \(\ell_\infty\) error at rate \(O(h^{p-2})\), where \(h\) is the side length of the leaves of the discretization tree.
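To read off the observed rate from your own runs, one option is a least-squares fit of \(\log(\text{error})\) against \(\log h\). The slope of that fit should be close to \(p-2\). The sketch below is not part of the example script, and the \((h, \text{error})\) data in it are synthetic:

```python
import numpy as np


def observed_order(h, err):
    """Least-squares slope of log(err) versus log(h).

    If err ~ C * h**r, the returned slope approximates r. For the DtN/ItI
    convergence study, r should be close to p - 2.
    """
    h = np.asarray(h, dtype=float)
    err = np.asarray(err, dtype=float)
    # Fit log(err) = r * log(h) + log(C); polyfit returns [r, log(C)].
    slope, _ = np.polyfit(np.log(h), np.log(err), deg=1)
    return slope


# Hypothetical data: leaf side lengths from 2 to 5 levels of uniform
# refinement of a domain of side length 2, with synthetic errors C * h**(p - 2).
p = 8
h = 2.0 / 2.0 ** np.arange(2, 6)
err = 3.0e-2 * h ** (p - 2)
print(round(observed_order(h, err), 6))  # 6.0
```

With real data, the smallest \(h\) values may hit round-off, so it can help to fit only the pre-asymptotic and asymptotic range rather than every point.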
High-wavenumber scattering problem
The high-wavenumber scattering example is a GPU implementation of the solver presented in [1]. The solver uses our jaxhps.upward_pass_subtree() routine to generate a top-level ItI matrix, sets up and solves a boundary integral equation to enforce the radiation condition, and then propagates impedance data to the interior points using the jaxhps.downward_pass_subtree() routine.
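An ItI matrix maps incoming impedance data to outgoing impedance data on a boundary. Under the convention \(f = \partial_n u + i\eta u\) (incoming) and \(g = \partial_n u - i\eta u\) (outgoing), which is common in the HPS literature but is an assumption here (jaxhps may use a different sign or choice of \(\eta\)), a DtN matrix \(T\) determines the ItI matrix through \(R = (T - i\eta I)(T + i\eta I)^{-1}\). For Helmholtz problems, the ItI map remains well defined at wavenumbers where the DtN map does not exist (interior resonances), which is one motivation for using it at high wavenumber. The helper iti_from_dtn below is hypothetical and only illustrates the relationship on a small dense matrix:

```python
import numpy as np


def iti_from_dtn(T, eta):
    """Convert a discrete DtN matrix T into an ItI matrix R.

    Assumed convention: incoming f = u_n + 1j*eta*u, outgoing
    g = u_n - 1j*eta*u, and u_n = T @ u on the boundary nodes.
    Then f = (T + 1j*eta*I) u and g = (T - 1j*eta*I) u, so
    R = (T - 1j*eta*I) @ inv(T + 1j*eta*I).
    """
    n = T.shape[0]
    eye = np.eye(n)
    A = T + 1j * eta * eye
    B = T - 1j * eta * eye
    # Solve R @ A = B as A.T @ R.T = B.T instead of forming inv(A).
    return np.linalg.solve(A.T, B.T).T


# Usage with a hypothetical symmetric matrix standing in for a DtN matrix.
rng = np.random.default_rng(0)
M = rng.standard_normal((6, 6))
T = M + M.T
eta = 2.0
R = iti_from_dtn(T, eta)

u = rng.standard_normal(6)
f = T @ u + 1j * eta * u  # incoming impedance data
g = T @ u - 1j * eta * u  # outgoing impedance data
print(np.allclose(R @ f, g))  # True
```

For a real symmetric \(T\), \(R\) is a Cayley transform of \(T\) and is therefore unitary, which can be checked with np.allclose(R @ R.conj().T, np.eye(6)).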
To run the example, you will need the exterior single- and double-layer potential matrices, which define the boundary integral equation for the scattering problem. You can download these matrices from Zenodo: https://doi.org/10.5281/zenodo.17259087. Alternatively, you can run the MATLAB script examples/driver_gen_SD_matrices.m, which generates and saves the exterior single- and double-layer potential matrices; you can also use this script to generate new potential matrices for different domain sizes, discretization levels, and wavenumbers. Once the matrices are in place, you can run the script:
Note
This script only runs and times the code once. To see the large effect of JAX’s just-in-time compilation, you may want to edit the script to compute the solution multiple times.
python examples/wave_scattering_compute_reference_soln.py \
--scattering_potential gauss_bumps \
-k 100 \
--plot_utot
This will generate plots like the following, showing the scattering potential and the real part and modulus of the total field:
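As the note above suggests, the effect of JAX's just-in-time compilation only becomes visible when the same computation runs more than once. One way to time repeated calls is sketched below; the helper names are hypothetical and not part of jaxhps. Because JAX dispatches work asynchronously, the helper waits on any output that has a block_until_ready() method before stopping the clock:

```python
import time
from typing import Any, Callable, List


def _wait(out: Any) -> None:
    """Block until asynchronous JAX work producing `out` has finished."""
    # jax.Array objects expose block_until_ready(); tuples and lists
    # (e.g. several outputs) are walked element by element.
    if isinstance(out, (tuple, list)):
        for item in out:
            _wait(item)
    elif hasattr(out, "block_until_ready"):
        out.block_until_ready()


def time_repeated(fn: Callable[..., Any], *args: Any, n_runs: int = 3) -> List[float]:
    """Call fn(*args) n_runs times and return the wall-clock seconds of each call."""
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _wait(fn(*args))
        times.append(time.perf_counter() - start)
    return times
```

For a function wrapped in jax.jit, the first entry returned by time_repeated typically includes tracing and compilation time and is therefore much larger than the later entries.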
Accuracy of automatic differentiation
In this example, we compare the output of JAX’s automatic differentiation to the action of the Fréchet derivative of the following function:
where \(u\) is the solution to the variable-coefficient Helmholtz equation with scattering potential \(q_\theta\) specified by basis coefficients \(\theta\). We compute the Fréchet derivative by solving auxiliary PDEs; see [2]. We include separate scripts for assessing the accuracy of Jacobian-vector products and vector-Jacobian products, and a third script for plotting:
python examples/check_autodiff_Jvp.py
python examples/check_autodiff_vJp.py
python examples/plot_autodiff_data.py
This produces a figure showing the convergence of the two autodiff methods:
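The scripts above compare autodiff against a Fréchet derivative computed from auxiliary PDEs. A cheaper, independent sanity check that works for any differentiable forward map is a central finite difference; this is a different technique from the one used in the scripts. The sketch below applies it to a hypothetical toy model (toy_forward_model is an assumption, not part of jaxhps) that maps real coefficients to complex values, mirroring the real-\(\theta\), complex-\(u\) setting:

```python
import jax
import jax.numpy as jnp

# Double precision keeps round-off small relative to finite-difference error.
jax.config.update("jax_enable_x64", True)


def toy_forward_model(theta, A):
    """Hypothetical stand-in for a forward model mapping real theta to complex u."""
    return jnp.exp(1j * (A @ theta)) * (1.0 + jnp.sum(theta**2))


def jvp_fd_error(theta, v, A, eps):
    """Relative difference between jax.jvp and a central finite difference."""
    f = lambda t: toy_forward_model(t, A)
    _, jv_ad = jax.jvp(f, (theta,), (v,))
    jv_fd = (f(theta + eps * v) - f(theta - eps * v)) / (2.0 * eps)
    return jnp.linalg.norm(jv_ad - jv_fd) / jnp.linalg.norm(jv_ad)


key_a, key_t, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (8, 4))
theta = jax.random.normal(key_t, (4,))
v = jax.random.normal(key_v, (4,))

for eps in (1e-2, 1e-3, 1e-4):
    # Central differences have O(eps**2) truncation error, so the printed
    # differences should shrink roughly 100x per step until round-off dominates.
    print(eps, float(jvp_fd_error(theta, v, A, eps)))
```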
Inverse wave scattering using automatic differentiation
This example is a 2D inverse scattering problem where we try to recover the basis coefficients \(\theta\) of a scattering potential \(q\). Using automatic differentiation with our code is simple. We want to be able to compute Jacobian-vector products:
where \(J[\theta_t]\) is the Jacobian of the forward model evaluated at \(\theta_t\), and \(v\) is an arbitrary vector. We also want to compute vector-Jacobian products:
Computing both of these objects is easy:
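import jax

# See the examples directory in the source repo
from inverse_scattering_utils import forward_model

# theta_t: current parameter estimate; v: tangent vector with the shape of theta_t.
# vjp_raw : w -> w^\top J[\theta_t]   (w has the shape of u_t)
# u_t = forward_model(theta_t)
u_t, vjp_raw = jax.vjp(forward_model, theta_t)
# Need to conjugate because we're using complex numbers.
# The pullback returns a tuple with one entry per primal input, hence [0].
# A new name avoids rebinding vjp_fn inside its own lambda (infinite recursion).
vjp_fn = lambda w: vjp_raw(w.conjugate())[0].conjugate()

# Jv is the evaluation of J[\theta_t] v, not a function.
_, Jv = jax.jvp(forward_model, (theta_t,), (v,))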
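The conjugation in the snippet above can be checked on a small problem. For real parameters \(\theta\), a complex forward model with Jacobian \(J\), and residual \(r = F(\theta) - d\), the gradient of the misfit \(\tfrac12\|F(\theta) - d\|^2\) is \(\mathrm{Re}(J^{*} r)\), where \(J^{*}\) is the conjugate transpose. Applying the conjugated pullback to \(r\) should therefore reproduce jax.grad of the misfit. The sketch uses a hypothetical toy model (toy_forward_model, misfit, and the random data are not part of jaxhps) in place of forward_model:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_forward_model(theta, A):
    """Hypothetical stand-in for forward_model: real theta -> complex u."""
    return jnp.exp(1j * (A @ theta))


def misfit(theta, A, d):
    """Real-valued least-squares misfit 0.5 * ||F(theta) - d||^2."""
    r = toy_forward_model(theta, A) - d
    return 0.5 * jnp.sum(jnp.real(r * jnp.conj(r)))


key_a, key_t, key_d = jax.random.split(jax.random.PRNGKey(1), 3)
A = jax.random.normal(key_a, (6, 3))
theta_t = jax.random.normal(key_t, (3,))
d = jnp.exp(1j * jax.random.normal(key_d, (6,)))  # synthetic "data"

# Same pattern as the snippet above, applied to the toy model.
u_t, vjp_raw = jax.vjp(lambda t: toy_forward_model(t, A), theta_t)
vjp_fn = lambda w: vjp_raw(w.conjugate())[0].conjugate()

# Pulling back the residual r = u_t - d gives Re(J^* r), the misfit gradient.
grad_via_vjp = vjp_fn(u_t - d)
grad_direct = jax.grad(misfit)(theta_t, A, d)
print(jnp.allclose(grad_via_vjp, grad_direct))
```

Without conjugating the input, the pullback would return \(\mathrm{Re}(J^{\top} r)\) instead, which in general is not the gradient of the misfit.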
To run the example, you need to generate the single- and double-layer kernel matrices using the MATLAB script examples/driver_gen_SD_matrices.m, if you haven’t already done so. Once these matrices are in place, you can run the inverse scattering example using the command line:
python examples/inverse_wave_scattering.py --n_iter 20
In this example, we are trying to recover the low-frequency sine basis coefficients \(\theta\) of the scattering potential from the forward wave scattering example. Running the code should produce this plot showing the convergence of the algorithm:
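The inverse problem stays low-dimensional because the potential is described by a few coefficients. One way such a parametrization could look is a truncated double sine series. The function below, its name, the domain \([0, L]^2\), and the layout of \(\theta\) as an \(n \times n\) array are assumptions for illustration, not the definitions used in inverse_wave_scattering.py or inverse_scattering_utils:

```python
import numpy as np


def sine_basis_potential(theta, x, y, length=1.0):
    """Evaluate a hypothetical low-frequency sine expansion q_theta(x, y).

    q(x, y) = sum_{j,k} theta[j, k] * sin((j+1) pi x / L) * sin((k+1) pi y / L)
    on the assumed domain [0, L]^2. theta has shape (n, n), so only the n
    lowest frequencies in each direction appear.
    """
    theta = np.asarray(theta, dtype=float)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    q = np.zeros(np.broadcast(x, y).shape)
    for j in range(theta.shape[0]):
        for k in range(theta.shape[1]):
            q += theta[j, k] * np.sin((j + 1) * np.pi * x / length) * np.sin(
                (k + 1) * np.pi * y / length
            )
    return q


# Usage: a 3 x 3 block of coefficients evaluated on a 64 x 64 grid.
xs = np.linspace(0.0, 1.0, 64)
X, Y = np.meshgrid(xs, xs, indexing="ij")
theta_true = np.array([[1.0, 0.0, 0.2], [0.0, -0.5, 0.0], [0.3, 0.0, 0.0]])
q = sine_basis_potential(theta_true, X, Y)
```

Keeping only low frequencies makes the reconstruction smoother and better conditioned, at the cost of not being able to represent fine-scale features of the potential.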
Adaptive discretization on a 3D problem with known solution
We have a script for generating adaptive discretizations for the wavefront problem presented in our paper:
python examples/wavefront_adaptive_discretization_3D.py -p 10 --tol 1e-02 1e-05
This should produce an image showing the computed solution, generated grid, and error map for each specified tolerance level. Here is the result for the tolerance level \(10^{-5}\):
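The idea behind tolerance-driven adaptive discretization can be illustrated in one dimension: split a leaf whenever a degree-\((p-1)\) polynomial interpolant on it misses the target function by more than the tolerance. The sketch below is a simplified 1D analogue for illustration only; adapt_1d, its error estimate (sampling on a finer grid), and the arctan test function are assumptions and not the refinement criterion implemented in jaxhps:

```python
import numpy as np
from numpy.polynomial import Chebyshev


def adapt_1d(f, a, b, p, tol, max_depth=30):
    """Return sorted leaf intervals [(a_i, b_i), ...] of an adaptive binary tree.

    A leaf is accepted when a degree-(p-1) Chebyshev interpolant (p points)
    matches f to within tol (absolute) on a finer check grid; otherwise the
    leaf is split in half.
    """
    leaves = []
    stack = [(a, b, 0)]
    while stack:
        lo, hi, depth = stack.pop()
        interp = Chebyshev.interpolate(f, p - 1, domain=[lo, hi])
        xs = np.linspace(lo, hi, 4 * p)
        err = np.max(np.abs(interp(xs) - f(xs)))
        if err <= tol or depth >= max_depth:
            leaves.append((lo, hi))
        else:
            mid = 0.5 * (lo + hi)
            stack.extend([(mid, hi, depth + 1), (lo, mid, depth + 1)])
    return sorted(leaves)


# Hypothetical steep "wavefront" at x = 0.3.
front = lambda x: np.arctan(200.0 * (x - 0.3))
for tol in (1e-2, 1e-5):
    leaves = adapt_1d(front, 0.0, 1.0, p=10, tol=tol)
    widths = [hi - lo for lo, hi in leaves]
    print(tol, len(leaves), min(widths))
```

Tightening the tolerance concentrates more and smaller leaves around the steep front, which is the kind of refinement pattern one expects near a wavefront.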
Adaptive discretization on the linearized Poisson–Boltzmann equation
We have a script for generating adaptive discretizations of the linearized Poisson–Boltzmann equation applied to a simulated molecular configuration with 50 atoms:
python examples/poisson_boltzmann_example.py --tol 1e-01 1e-02 -p 10
This should produce output giving information about the generated grid and solution time for each specified tolerance level. In addition, it plots the generated grid with the permittivity. Here is the result for the tolerance level \(10^{-4}\) and polynomial degree \(p=10\) (the command above only requests tolerances \(10^{-1}\) and \(10^{-2}\); include 1e-04 in the --tol list to produce this case):
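For reference, the linearized Poisson–Boltzmann equation for the electrostatic potential \(\phi\) of a molecule with \(N\) point charges (here \(N = 50\) atoms) is commonly written, up to unit-dependent constants, as

\[
-\nabla \cdot \big( \epsilon(x)\, \nabla \phi(x) \big) + \bar{\kappa}^2(x)\, \phi(x) = \sum_{i=1}^{N} q_i\, \delta(x - x_i),
\]

where \(\epsilon\) is the permittivity, \(\bar{\kappa}\) is the modified Debye–Hückel screening parameter (zero inside the molecule), and \(q_i\) and \(x_i\) are the atomic charges and positions. This is the textbook form; the example script may use a different scaling, a smoothed permittivity, or a regularized source term.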