API for solution methods

The package offers a few different ways to interact with the fast direct solvers generated by HPS algorithms.

Building the Fast Direct Solver

In many use cases, one wants to construct a direct solver once and then apply it to many different boundary conditions in sequence. The functions jaxhps.build_solver() and jaxhps.solve() are designed for this workflow.
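The build-once, solve-many workflow can be sketched as follows. This is a minimal sketch, not the package's own code: solve_many is a hypothetical helper, and the two callables are passed in as arguments (in practice they would be jaxhps.build_solver and jaxhps.solve) so that the control flow can be read without constructing a PDEProblem, which is not covered on this page.

```python
from typing import Any, Callable, List, Sequence


def solve_many(
    pde_problem: Any,
    boundary_data_seq: Sequence[Any],
    build_solver: Callable[[Any], Any],
    solve: Callable[[Any, Any], Any],
) -> List[Any]:
    """Hypothetical helper: build the solver once, then reuse it.

    `build_solver` and `solve` are assumed to behave like
    jaxhps.build_solver and jaxhps.solve as documented on this page.
    """
    # Expensive step, done once: forms the solution operators and
    # stores them inside pde_problem.
    build_solver(pde_problem)
    # Cheaper step, repeated: one downward pass per set of boundary data.
    return [solve(pde_problem, g) for g in boundary_data_seq]
```

With the package installed, a call might look like solve_many(pde_problem, [g0, g1], jaxhps.build_solver, jaxhps.solve), where pde_problem, g0, and g1 are placeholders for objects the user has already created.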

Note

By default, jaxhps.build_solver() moves data from the GPU to the CPU while building the solution operator. These transfers can significantly increase the time required to build the solver. If only one solution is required, the subtree recomputation methods may be preferable.

jaxhps.build_solver(pde_problem: PDEProblem, return_top_T: bool = False, compute_device: Device = jax.devices()[0], host_device: Device = jax.devices('cpu')[0]) None | Array

This function builds all of the matrices for the fast direct solver. This consists of performing a local solve stage on each leaf and then merging information from the leaves up to the root of the domain. If the PDEProblem specifies a uniform 2D problem and the source term is not specified, this function builds a solver for arbitrary source terms, which requires storing a few more matrices.

This function performs the computation on compute_device and then transfers the data to host_device.
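To make the two device defaults concrete, the following JAX snippet shows what they resolve to and how an array is moved from one to the other. It uses only standard JAX calls (jax.devices and jax.device_put) and illustrates the device-placement pattern in general; it is not the internals of jaxhps.build_solver(), and the array contents are made up for the example.

```python
import jax
import jax.numpy as jnp

# The same default expressions that appear in the signature above.
compute_device = jax.devices()[0]      # first device of the default backend
host_device = jax.devices("cpu")[0]    # first CPU device

# Place an array on the compute device and do some work there.
a = jax.device_put(jnp.ones((4, 4)), compute_device)
b = a @ a

# Move the result to the host device for storage.
b_host = jax.device_put(b, host_device)
```

On a machine without an accelerator, both names refer to the same CPU device and the final transfer is trivial.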

The function will save the solution operators in the PDEProblem object.

The function can optionally return the top-level Poincare–Steklov operator T. This is useful for problems, such as wave scattering, where one wants to couple the solver in the computational domain with a boundary integral equation defined on the domain’s boundary.

To compute solutions of the PDE, one must call jaxhps.solve() after this function.

Parameters:
  • pde_problem (PDEProblem) – Specifies the differential operator, source, domain, and precomputed interpolation and differentiation matrices.

  • return_top_T (bool, optional) – If set to True, the function will return the computed top-level Poincare–Steklov matrix. Defaults to False.

  • compute_device (jax.Device, optional) – Where the computation should happen. Defaults to jax.devices()[0].

  • host_device (jax.Device, optional) – Where the solution operators should be stored. Defaults to jax.devices("cpu")[0].

Returns:

If return_top_T is set to True, the function will return the computed top-level Poincare–Steklov matrix. Otherwise, it returns None.

Return type:

None | jax.Array

jaxhps.solve(pde_problem: PDEProblem, boundary_data: Array | List[Array], source: Array = None, compute_device: Device = jax.devices()[0], host_device: Device = jax.devices('cpu')[0]) Array

This function performs the downward pass of the HPS algorithm, after the solution operators have been formed by a call to jaxhps.build_solver().

If the problem is a 2D uniform problem, the source term can be specified here. For other problems, the source term must be specified at the time the solver is built.

Parameters:
  • pde_problem (PDEProblem) – Specifies the differential operator, source, domain, and precomputed interpolation and differentiation matrices. Also contains all of the solution operators computed by jaxhps.build_solver().

  • boundary_data (jax.Array | List[jax.Array]) – This specifies the data on the boundary of the domain that will be propagated down to the interior of the leaves. If using an adaptive discretization, this must be specified as a list of arrays, one for each side or face of the root boundary. This list can be specified using the Domain.get_adaptive_boundary_data_lst() utility. For uniform discretizations, this argument can be a jax.Array of shape (n_bdry,) or a list.

  • source (jax.Array) – The source term for the PDE. Currently, this can only be specified for 2D uniform ItI problems. For other versions, the source must be specified at the time the solver is built.

  • compute_device (jax.Device, optional) – Where the computation should happen. Defaults to jax.devices()[0].

  • host_device (jax.Device, optional) – Where the solution operators should be stored. Defaults to jax.devices("cpu")[0].

Returns:

The solution on the HPS grid. This has shape (n_leaves, p^d).

Return type:

jax.Array
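As a quick sanity check on the returned shape (n_leaves, p^d), the sketch below computes it for a uniformly refined tree. It assumes, beyond what this page states, that p is the number of discretization points per dimension on each leaf and that a uniform tree with L levels of refinement has 2^(dL) leaves (4^L for a quadtree, 8^L for an octree). expected_solution_shape is a hypothetical helper, not part of jaxhps.

```python
from typing import Tuple


def expected_solution_shape(L: int, p: int, d: int = 2) -> Tuple[int, int]:
    """Hypothetical helper: shape (n_leaves, p^d) for a uniform tree.

    Assumes L levels of uniform refinement, so each level splits
    every node into 2^d children.
    """
    if d not in (2, 3):
        raise ValueError("d must be 2 or 3")
    if L < 0 or p < 1:
        raise ValueError("L must be >= 0 and p must be >= 1")
    n_leaves = (2 ** d) ** L
    return (n_leaves, p ** d)
```

Under these assumptions, a 2D problem with L = 3 and p = 16 would give a solution array with 64 rows and 256 columns.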

Subtree-Recomputation Solution Methods

To take full advantage of hardware acceleration, we designed subtree-recomputation methods. These methods are useful when you want to solve the PDE for a single right-hand side very quickly. If you want to impose a Dirichlet or Robin boundary condition and know the boundary data at runtime, you can use jaxhps.solve_subtree(). Otherwise, you can use the pair of functions jaxhps.upward_pass_subtree() and jaxhps.downward_pass_subtree(). The upward pass returns the domain’s Poincare–Steklov operator, which can be used, for example, to define a boundary integral equation that determines the boundary data. The downward pass uses the partially saved solution operators and performs recomputation where necessary.
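The two-stage workflow, where the boundary data depends on the top-level operator, can be sketched as follows. This is a hedged sketch rather than the package's own code: solve_with_deferred_boundary_data and boundary_data_from_T are hypothetical names, and the upward and downward passes are passed in as callables (in practice jaxhps.upward_pass_subtree and jaxhps.downward_pass_subtree). The point it illustrates is that both passes receive the same subtree_height.

```python
from typing import Any, Callable


def solve_with_deferred_boundary_data(
    pde_problem: Any,
    upward_pass: Callable[..., Any],
    boundary_data_from_T: Callable[[Any], Any],
    downward_pass: Callable[..., Any],
    subtree_height: int = 7,
) -> Any:
    """Hypothetical helper: upward pass, derive boundary data, downward pass."""
    # Upward pass: returns the top-level Poincare-Steklov matrix.
    T = upward_pass(pde_problem, subtree_height=subtree_height)
    # User-supplied step (an assumption of this sketch), e.g. solving a
    # boundary integral equation that involves T.
    boundary_data = boundary_data_from_T(T)
    # Downward pass: must use the same subtree_height as the upward pass.
    return downward_pass(
        pde_problem, boundary_data, subtree_height=subtree_height
    )
```

When the boundary data is already known, this indirection is unnecessary and a single call to jaxhps.solve_subtree() covers both passes.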

jaxhps.solve_subtree(pde_problem: PDEProblem, boundary_data: Array | List, subtree_height: int = 7, compute_device: Device = jax.devices()[0], host_device: Device = jax.devices('cpu')[0]) Array

This function solves the PDE using the novel subtree recomputation strategy. This algorithm is only supported for 2D problems with uniform quadtrees. The algorithm proceeds by splitting the problem into a set of subtrees, computing the outgoing data (T and h) for the root of each subtree, and then performing the highest-level merges. This greatly reduces the number of data movements between the CPU and GPU, at the cost of more floating-point operations.

Unlike the jaxhps.build_solver() method, this function does not save any of the solution matrices computed during the upward pass. Thus, it is most appropriate when we want to solve one instance of the problem very quickly.

Parameters:
  • pde_problem (PDEProblem) – Specifies the discretization, differential operator, source function, and keeps track of the pre-computed differentiation and interpolation matrices.

  • boundary_data (jax.Array | List) – Specifies the boundary data on the boundary discretization points. Can be a list or a jax.Array.

  • subtree_height (int, optional) – Height of the subtrees used in our recomputation algorithm. The default is what we found to be optimal for DtN merges using fp64 on an NVIDIA H100 GPU.

  • compute_device (jax.Device, optional) – Where the computation should be performed. This is typically a GPU device. Defaults to jax.devices()[0].

  • host_device (jax.Device, optional) – Device where the returned data lives. This is typically a CPU device. Defaults to jax.devices("cpu")[0].

Returns:

solns – Solutions to the boundary value problem on the HPS grid.

Return type:

jax.Array

jaxhps.upward_pass_subtree(pde_problem: PDEProblem, subtree_height: int = 7, compute_device: Device = jax.devices()[0], host_device: Device = jax.devices('cpu')[0]) Array

Performs the upward pass of the subtree recomputation algorithm, returns the top-level Poincare–Steklov matrix, and stores the high-level \(S\) and \(\tilde{g}\) data. This is meant to be used in conjunction with jaxhps.downward_pass_subtree() for large problems where the boundary data must be specified after the upward pass. One example is wave scattering, where the boundary impedance values cannot be computed without the top-level ItI matrix.

Parameters:
  • pde_problem (PDEProblem) – Specifies the discretization, differential operator, source function, and keeps track of the pre-computed differentiation and interpolation matrices.

  • subtree_height (int, optional) – Height of the subtrees used in our recomputation algorithm. The default is what we found to be optimal for DtN merges using fp64 on an NVIDIA H100 GPU.

  • compute_device (jax.Device, optional) – Where the computation should be performed. This is typically a GPU device. Defaults to jax.devices()[0].

  • host_device (jax.Device, optional) – Device where the returned data lives. This is typically a CPU device. Defaults to jax.devices("cpu")[0].

Returns:

T_last – Top-level Poincare–Steklov matrix for the whole domain.

Return type:

jax.Array

jaxhps.downward_pass_subtree(pde_problem: PDEProblem, boundary_data: Array, subtree_height: int = 7, compute_device: Device = jax.devices()[0], host_device: Device = jax.devices('cpu')[0]) Array

Performs the downward pass of the subtree recomputation algorithm. This is meant to be used in conjunction with jaxhps.upward_pass_subtree() for large problems where the boundary data must be specified after the upward pass. One example is wave scattering, where the boundary impedance values cannot be computed without the top-level ItI matrix.

Parameters:
  • pde_problem (PDEProblem) – Specifies the discretization, differential operator, source function, and keeps track of the pre-computed differentiation and interpolation matrices.

  • boundary_data (jax.Array | List) – Specifies the boundary data on the boundary discretization points. Can be a list or a jax.Array.

  • subtree_height (int, optional) – Height of the subtrees used in our recomputation algorithm. Must be the same as used in the upward pass.

  • compute_device (jax.Device, optional) – Where the computation should be performed. This is typically a GPU device. Defaults to jax.devices()[0].

  • host_device (jax.Device, optional) – Device where the returned data lives. This is typically a CPU device. Defaults to jax.devices("cpu")[0].

Returns:

solns – Solutions to the boundary value problem on the HPS grid. Has shape (n_leaves, p^2).

Return type:

jax.Array