Spectral discretization utilities

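The quadrature module provides utilities for computing quadrature points and weights for numerical integration. It also includes routines for mapping points from [-1, 1] to an arbitrary interval and for building barycentric Lagrange interpolation matrices and Chebyshev differentiation matrices.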

jaxhps.quadrature.affine_transform(pts: Array, ab: Array) → Array

Applies an affine map to the points pts, which are assumed to lie in the interval [-1, 1], sending them to the interval [a, b].

Parameters:
  • pts (jax.Array) – Has shape (n,)

  • ab (jax.Array) – Has shape (2,) and contains the interval endpoints [a, b]

Returns:

Has shape (n,)

Return type:

jax.Array
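The standard affine map from [-1, 1] to [a, b] is x ↦ (b - a)/2 · x + (a + b)/2. The sketch below illustrates that mapping under the assumption that ab holds [a, b]; the name affine_transform_sketch is hypothetical, and this is not necessarily how jaxhps implements it.

```python
import jax.numpy as jnp


def affine_transform_sketch(pts: jnp.ndarray, ab: jnp.ndarray) -> jnp.ndarray:
    """Map points in [-1, 1] to [a, b], where ab = [a, b] (hypothetical helper)."""
    a, b = ab[0], ab[1]
    # Scale by half the interval length, then shift to the interval midpoint.
    return 0.5 * (b - a) * pts + 0.5 * (a + b)


pts = jnp.array([-1.0, 0.0, 1.0])
print(affine_transform_sketch(pts, jnp.array([2.0, 6.0])))  # [2. 4. 6.]
```

The endpoints -1 and 1 land on a and b, and the midpoint 0 lands on (a + b)/2.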

jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_1D(from_pts: Array, to_pts: Array) → Array

Generates a 1D Lagrange polynomial interpolation matrix that interpolates from the points in from_pts to the points in to_pts.

This function uses the barycentric formula for Lagrange interpolation from [1].

Parameters:
  • from_pts (jax.Array) – Has shape (n,)

  • to_pts (jax.Array) – Has shape (p,)

Returns:

Has shape (p,n)

Return type:

jax.Array
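For reference, the barycentric formula evaluates the interpolant of samples f_j at distinct nodes x_j as p(x) = sum_j [w_j / (x - x_j)] f_j / sum_j [w_j / (x - x_j)], with weights w_j = 1 / prod_{k != j} (x_j - x_k). Row i of the interpolation matrix is therefore the normalized vector w_j / (t_i - x_j) for target t_i. The sketch below builds that (p, n) matrix in JAX. It is an illustrative reimplementation with a hypothetical name (interp_matrix_1d), not the jaxhps source; it assumes the entries of from_pts are distinct and special-cases targets that coincide exactly with a source node, where the formula would divide by zero.

```python
import jax.numpy as jnp


def interp_matrix_1d(from_pts: jnp.ndarray, to_pts: jnp.ndarray) -> jnp.ndarray:
    """Barycentric Lagrange interpolation matrix of shape (p, n) (hypothetical helper)."""
    n = from_pts.shape[0]
    # Barycentric weights w_j = 1 / prod_{k != j} (x_j - x_k); adding the identity
    # replaces the zero diagonal of the difference matrix with ones.
    diffs = from_pts[:, None] - from_pts[None, :] + jnp.eye(n)
    w = 1.0 / jnp.prod(diffs, axis=1)

    d = to_pts[:, None] - from_pts[None, :]  # (p, n) target-minus-source distances
    hit = d == 0.0                           # target coincides with a source node
    terms = w[None, :] / jnp.where(hit, 1.0, d)
    mat = terms / jnp.sum(terms, axis=1, keepdims=True)

    # Rows whose target is a source node reduce to an indicator row (exact copy).
    row_hit = jnp.any(hit, axis=1, keepdims=True)
    return jnp.where(row_hit, hit.astype(mat.dtype), mat)


def f(x):
    return x**3 - 2 * x


# Interpolate f from 8 Chebyshev points (increasing order) to 4 targets.
from_pts = jnp.cos(jnp.pi * jnp.arange(7, -1, -1) / 7)
to_pts = jnp.array([-0.5, 0.1, 0.75, 1.0])
M = interp_matrix_1d(from_pts, to_pts)
print(jnp.max(jnp.abs(M @ f(from_pts) - f(to_pts))))  # should be near float32 round-off
```

Because f is a cubic and there are 8 nodes, the interpolant reproduces it exactly in exact arithmetic, so any remaining error is round-off.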

jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_2D(from_pts_x: Array, from_pts_y: Array, to_pts_x: Array, to_pts_y: Array) → Array

2D Barycentric Lagrange interpolation matrix. A generalization of [1], modeled after the MATLAB code snippet [2].

The source grid is the tensor product of from_pts_x and from_pts_y. The columns of the resulting matrix are ordered to match samples at this list of points:

```python
source_X, source_Y = jnp.meshgrid(from_pts_x, from_pts_y, indexing="ij")
source_pts = jnp.stack((source_X.flatten(), source_Y.flatten()), axis=-1)
```

Similarly, the rows are ordered to match this list of target points:

```python
target_X, target_Y = jnp.meshgrid(to_pts_x, to_pts_y, indexing="ij")
target_pts = jnp.stack((target_X.flatten(), target_Y.flatten()), axis=-1)
```

Parameters:
  • from_pts_x (jax.Array) – Has shape (n_x,)

  • from_pts_y (jax.Array) – Has shape (n_y,)

  • to_pts_x (jax.Array) – Has shape (p_x,)

  • to_pts_y (jax.Array) – Has shape (p_y,)

Returns:

Has shape (p_x * p_y, n_x * n_y)

Return type:

jax.Array
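On a tensor-product grid, the 2D Lagrange interpolant is a product of 1D interpolants, and an "ij" meshgrid followed by a row-major flatten() makes the y index vary fastest (flat index i_x * n_y + i_y). In that ordering the matrix described above should coincide with the Kronecker product kron(M_x, M_y) of the two 1D matrices. The sketch below checks this numerically on a polynomial that both grids reproduce exactly. It is an illustration, not the jaxhps implementation: interp_matrix_1d, interp_matrix_2d, and cheb are hypothetical helpers, and the 1D helper repeats a compact barycentric construction so the block runs on its own.

```python
import jax.numpy as jnp


def interp_matrix_1d(src, tgt):
    """Compact 1D barycentric interpolation matrix (hypothetical helper)."""
    w = 1.0 / jnp.prod(src[:, None] - src[None, :] + jnp.eye(src.shape[0]), axis=1)
    d = tgt[:, None] - src[None, :]
    hit = d == 0.0
    t = w / jnp.where(hit, 1.0, d)
    m = t / t.sum(axis=1, keepdims=True)
    return jnp.where(hit.any(axis=1, keepdims=True), hit.astype(m.dtype), m)


def interp_matrix_2d(from_x, from_y, to_x, to_y):
    """Tensor-product 2D matrix, shape (p_x * p_y, n_x * n_y) (hypothetical helper)."""
    # "ij" indexing plus row-major flatten gives flat index i_x * n_y + i_y,
    # which is exactly the index layout of a Kronecker product.
    return jnp.kron(interp_matrix_1d(from_x, to_x), interp_matrix_1d(from_y, to_y))


def cheb(n):
    """Chebyshev points in increasing order (hypothetical helper)."""
    return jnp.cos(jnp.pi * jnp.arange(n - 1, -1, -1) / (n - 1))


def f(x, y):
    return x**2 * y + y**3  # degree 2 in x, 3 in y: exact on 6 x 5 grids


from_x, from_y = cheb(6), cheb(5)
to_x, to_y = jnp.array([-0.3, 0.4]), jnp.array([-0.8, 0.0, 0.6])

source_X, source_Y = jnp.meshgrid(from_x, from_y, indexing="ij")
target_X, target_Y = jnp.meshgrid(to_x, to_y, indexing="ij")

M = interp_matrix_2d(from_x, from_y, to_x, to_y)
approx = M @ f(source_X, source_Y).flatten()
exact = f(target_X, target_Y).flatten()
print(M.shape, jnp.max(jnp.abs(approx - exact)))
```

The Kronecker form also shows why the matrix can be large: it stores p_x * p_y * n_x * n_y entries, whereas the two 1D factors hold only p_x * n_x + p_y * n_y.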

jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_3D(from_pts_x: Array, from_pts_y: Array, from_pts_z: Array, to_pts_x: Array, to_pts_y: Array, to_pts_z: Array) → Array

3D Barycentric Lagrange interpolation matrix. A generalization of [1].

The source grid is the tensor product of from_pts_x, from_pts_y, and from_pts_z. The columns of the resulting matrix are ordered to match samples at this list of points:

```python
source_X, source_Y, source_Z = jnp.meshgrid(from_pts_x, from_pts_y, from_pts_z, indexing="ij")
source_pts = jnp.stack((source_X.flatten(), source_Y.flatten(), source_Z.flatten()), axis=-1)
```

Similarly, the rows are ordered to match this list of target points:

```python
target_X, target_Y, target_Z = jnp.meshgrid(to_pts_x, to_pts_y, to_pts_z, indexing="ij")
target_pts = jnp.stack((target_X.flatten(), target_Y.flatten(), target_Z.flatten()), axis=-1)
```

Parameters:
  • from_pts_x (jax.Array) – Has shape (n_x,)

  • from_pts_y (jax.Array) – Has shape (n_y,)

  • from_pts_z (jax.Array) – Has shape (n_z,)

  • to_pts_x (jax.Array) – Has shape (p_x,)

  • to_pts_y (jax.Array) – Has shape (p_y,)

  • to_pts_z (jax.Array) – Has shape (p_z,)

Returns:

Has shape (p_x * p_y * p_z, n_x * n_y * n_z)

Return type:

jax.Array
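The same tensor-product reasoning extends to 3D: with "ij" indexing and a row-major flatten, the flat index is (i_x * n_y + i_y) * n_z + i_z, which matches the layout of kron(kron(M_x, M_y), M_z). The sketch below assumes that the matrix described above is this tensor-product interpolant and checks the ordering on a polynomial the grids reproduce exactly. interp_matrix_1d, interp_matrix_3d, and cheb are hypothetical helpers, not jaxhps functions.

```python
import jax.numpy as jnp


def interp_matrix_1d(src, tgt):
    """Compact 1D barycentric interpolation matrix (hypothetical helper)."""
    w = 1.0 / jnp.prod(src[:, None] - src[None, :] + jnp.eye(src.shape[0]), axis=1)
    d = tgt[:, None] - src[None, :]
    hit = d == 0.0
    t = w / jnp.where(hit, 1.0, d)
    m = t / t.sum(axis=1, keepdims=True)
    return jnp.where(hit.any(axis=1, keepdims=True), hit.astype(m.dtype), m)


def interp_matrix_3d(from_x, from_y, from_z, to_x, to_y, to_z):
    """Tensor-product 3D matrix, shape (p_x*p_y*p_z, n_x*n_y*n_z) (hypothetical helper)."""
    mx = interp_matrix_1d(from_x, to_x)
    my = interp_matrix_1d(from_y, to_y)
    mz = interp_matrix_1d(from_z, to_z)
    # z varies fastest, then y, then x, mirroring the "ij" meshgrid flatten order.
    return jnp.kron(jnp.kron(mx, my), mz)


def cheb(n):
    """Chebyshev points in increasing order (hypothetical helper)."""
    return jnp.cos(jnp.pi * jnp.arange(n - 1, -1, -1) / (n - 1))


def f(x, y, z):
    return x**3 + x * y**4 - z**2  # degrees 3, 4, 2 fit grids of 4, 5, 3 points


from_x, from_y, from_z = cheb(4), cheb(5), cheb(3)
to_x, to_y, to_z = jnp.array([0.2]), jnp.array([-0.5, 0.5]), jnp.array([-0.1, 0.3, 0.9])

src = jnp.meshgrid(from_x, from_y, from_z, indexing="ij")
tgt = jnp.meshgrid(to_x, to_y, to_z, indexing="ij")

M = interp_matrix_3d(from_x, from_y, from_z, to_x, to_y, to_z)
approx = M @ f(*src).flatten()
exact = f(*tgt).flatten()
print(M.shape, jnp.max(jnp.abs(approx - exact)))
```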

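jaxhps.quadrature.chebyshev_points(n: int) → Array

Returns n Chebyshev points (the Chebyshev extreme points, also called Chebyshev points of the second kind) over the interval [-1, 1]:

out[i] = cos(pi * (n-1 - i) / (n-1)) for i = 0, …, n-1

The points are returned in increasing order, starting at the left endpoint -1.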

Parameters:

n (int) – Number of Chebyshev points to return

Returns:

The sampled points in [-1, 1] and the corresponding angles in [0, pi]

Return type:

jax.Array
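The formula for out[i] translates directly into a few lines of JAX. The sketch below evaluates only that formula and returns the points; the name chebyshev_points_sketch is hypothetical, and it assumes n >= 2 because n = 1 would divide by zero.

```python
import jax.numpy as jnp


def chebyshev_points_sketch(n: int) -> jnp.ndarray:
    """out[i] = cos(pi * (n-1 - i) / (n-1)); assumes n >= 2 (hypothetical helper)."""
    i = jnp.arange(n)
    angles = jnp.pi * (n - 1 - i) / (n - 1)  # runs from pi down to 0
    return jnp.cos(angles)                   # so the points increase from -1 to 1


print(chebyshev_points_sketch(5))  # approx. [-1, -0.7071, 0, 0.7071, 1]
```

Because cosine is decreasing on [0, pi], listing the angles from pi down to 0 is what puts the left endpoint first.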

jaxhps.quadrature.chebyshev_weights(n: int, bounds: array) → array

Generates weights for a Chebyshev quadrature rule with n points over the interval [a, b].

Uses the Clenshaw-Curtis quadrature rule, specifically the version used in Chebfun. See [3].

Parameters:
  • n (int) – Number of quadrature points

  • bounds (jnp.array) – Has shape (2,) and contains the interval endpoints [a, b]

Returns:

Has shape (n,) and contains the quadrature weights

Return type:

jnp.array
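Rather than reproduce Chebyshev quadrature as Chebfun computes it, the sketch below swaps in the direct O(n^2) Clenshaw-Curtis formula for the n = N + 1 Chebyshev extreme points: w_j = (c_j / N) * (1 - sum_{k=1}^{floor(N/2)} b_k cos(2 k theta_j) / (4 k^2 - 1)), with theta_j = pi j / N, c_j = 1 at the endpoints and 2 otherwise, and b_k = 1 when k = N/2 and 2 otherwise. The weights on [-1, 1] are then scaled by (b - a)/2. It should agree with other Clenshaw-Curtis implementations up to round-off, but the name clenshaw_curtis_weights_sketch is hypothetical and this is not the jaxhps code. The weights are symmetric, so they pair equally well with points listed in increasing or decreasing order.

```python
import jax.numpy as jnp


def clenshaw_curtis_weights_sketch(n: int, bounds: jnp.ndarray) -> jnp.ndarray:
    """Clenshaw-Curtis weights for n >= 2 Chebyshev points on [a, b] (hypothetical helper)."""
    N = n - 1
    j = jnp.arange(n)
    theta = jnp.pi * j / N
    k = jnp.arange(1, N // 2 + 1)
    b_k = jnp.where(2 * k == N, 1.0, 2.0)  # the k = N/2 term (N even) counts once
    series = jnp.sum(
        b_k[None, :] / (4 * k[None, :] ** 2 - 1) * jnp.cos(2 * k[None, :] * theta[:, None]),
        axis=1,
    )
    c = jnp.where((j == 0) | (j == N), 1.0, 2.0)  # endpoints get half weight
    w = c / N * (1.0 - series)                   # weights on [-1, 1]; they sum to 2
    a, b = bounds[0], bounds[1]
    return 0.5 * (b - a) * w                     # rescale to [a, b]


# Usage: integrate x**4 over [0, 2] (exact value 32/5 = 6.4).
n = 9
w = clenshaw_curtis_weights_sketch(n, jnp.array([0.0, 2.0]))
x = 1.0 + jnp.cos(jnp.pi * jnp.arange(n) / (n - 1))  # Chebyshev points mapped to [0, 2]
print(jnp.sum(w * x**4))  # close to 6.4
```

An n-point Clenshaw-Curtis rule integrates polynomials of degree up to n - 1 exactly, so with n = 9 the quartic above is integrated exactly up to round-off.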

jaxhps.quadrature.differentiation_matrix_1D(points: Array) → Array

Creates a 1D Chebyshev differentiation matrix, as described in [4], Ch. 6.

Expects Chebyshev points on the interval [-1, 1].

Parameters:

points (jnp.ndarray) – Has shape (p,)

Returns:

Has shape (p,p)

Return type:

jnp.ndarray
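A common construction, used for example in Trefethen's cheb.m, sets the off-diagonal entries to D_ij = (c_i / c_j) (-1)^(i+j) / (x_i - x_j), with c_i = 2 at the two endpoints and 1 otherwise, and fills the diagonal with minus the off-diagonal row sums so that D maps constants to zero. The sketch below applies that construction to Chebyshev points in the increasing order produced by the chebyshev_points formula (reversing the order leaves the sign pattern unchanged). The name chebyshev_diff_matrix_sketch is hypothetical, and this is not claimed to be the jaxhps code.

```python
import jax.numpy as jnp


def chebyshev_diff_matrix_sketch(points: jnp.ndarray) -> jnp.ndarray:
    """(p, p) differentiation matrix for increasing Chebyshev points (hypothetical helper)."""
    p = points.shape[0]
    i = jnp.arange(p)
    # c_i = 2 at the endpoints and 1 in the interior, times the alternating sign (-1)^i.
    c = jnp.where((i == 0) | (i == p - 1), 2.0, 1.0) * jnp.where(i % 2 == 0, 1.0, -1.0)
    dx = points[:, None] - points[None, :]
    # Off-diagonal entries (c_i / c_j) / (x_i - x_j); the identity avoids 0/0 on the diagonal.
    D = (c[:, None] / c[None, :]) / (dx + jnp.eye(p))
    # Negative sum trick: choose the diagonal so that every row sums to zero.
    return D - jnp.diag(jnp.sum(D, axis=1))


n = 6
x = jnp.cos(jnp.pi * (n - 1 - jnp.arange(n)) / (n - 1))  # increasing Chebyshev points
D = chebyshev_diff_matrix_sketch(x)
print(jnp.max(jnp.abs(D @ x**3 - 3 * x**2)))  # exact up to round-off for degree <= n - 1
```

Computing the diagonal from row sums, rather than from its closed form, is the usual choice because it keeps D @ ones close to zero in floating point.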

jaxhps.quadrature.gauss_points(n: int) → Array

Returns n Gauss-Legendre points over the interval [-1, 1]. This is a wrapper for numpy.polynomial.legendre.leggauss.

Parameters:

n (int) – Number of points

Returns:

Has shape (n,)

Return type:

jax.Array
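numpy.polynomial.legendre.leggauss(n) returns both the n Gauss-Legendre nodes and the matching weights on [-1, 1]. The sketch below keeps both, converts them to JAX arrays, and checks the rule's defining property: an n-point rule integrates polynomials of degree up to 2n - 1 exactly. The helper name gauss_legendre_rule is hypothetical; per its documented return shape, gauss_points itself returns only the points.

```python
import numpy as np
import jax.numpy as jnp


def gauss_legendre_rule(n: int):
    """Gauss-Legendre nodes and weights on [-1, 1] as JAX arrays (hypothetical helper)."""
    x, w = np.polynomial.legendre.leggauss(n)  # NumPy float64 arrays of shape (n,)
    return jnp.asarray(x), jnp.asarray(w)


x, w = gauss_legendre_rule(4)
# A 4-point rule is exact through degree 7, so it integrates x**6 over [-1, 1] exactly.
print(jnp.sum(w * x**6), 2 / 7)
```

With 64-bit mode disabled (the JAX default), jnp.asarray stores the nodes and weights as float32, so agreement is to roughly seven significant digits.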

References