Spectral discretization utilities
The quadrature module provides utilities for computing quadrature rules and weights for numerical integration. It also includes routines for building barycentric Lagrange interpolation matrices.
- jaxhps.quadrature.affine_transform(pts: Array, ab: Array) → Array
Applies an affine transformation that maps the points pts, which are assumed to lie in the interval [-1, 1], to the interval [a, b], where ab = [a, b].
- Parameters:
pts (jax.Array) – Has shape (n,)
ab (jax.Array) – Has shape (2,)
- Returns:
Has shape (n,)
- Return type:
jax.Array
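Written out, the transform sends each t in pts to a + (b - a)(t + 1)/2, so -1 maps to a and 1 maps to b. The following is a minimal JAX sketch of that formula. affine_transform_sketch is a hypothetical name used for illustration, not the jaxhps implementation.

```python
import jax.numpy as jnp


def affine_transform_sketch(pts: jnp.ndarray, ab: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch: map pts from [-1, 1] to [a, b], where ab = [a, b]."""
    a, b = ab[0], ab[1]
    # Linear map with t = -1 -> a and t = 1 -> b.
    return a + (b - a) * (pts + 1.0) / 2.0


# Usage: map three reference points to [2, 6].
print(affine_transform_sketch(jnp.array([-1.0, 0.0, 1.0]), jnp.array([2.0, 6.0])))  # [2. 4. 6.]
```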
- jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_1D(from_pts: Array, to_pts: Array) → Array
Generates a 1D Lagrange polynomial interpolation matrix that interpolates from the points in from_pts to the points in to_pts.
This function uses the barycentric formula for Lagrange interpolation from [1].
- Parameters:
from_pts (jax.Array) – Has shape (n,)
to_pts (jax.Array) – Has shape (p,)
- Returns:
Has shape (p,n)
- Return type:
jax.Array
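As an illustration of the barycentric formula, the sketch below builds the (p, n) matrix with entries (w_j / (y_i - x_j)) / Σ_k (w_k / (y_i - x_k)), where x are the source points, y are the target points, and w_j = 1 / Π_{k≠j} (x_j - x_k) are the barycentric weights. Rows whose target point coincides with a source point are replaced by the corresponding unit vector. This is a sketch under those assumptions: barycentric_matrix_1d_sketch is a hypothetical name, and jaxhps may organize the computation differently.

```python
import jax.numpy as jnp


def barycentric_matrix_1d_sketch(from_pts: jnp.ndarray, to_pts: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch of a (p, n) barycentric Lagrange interpolation matrix."""
    n = from_pts.shape[0]
    # Barycentric weights w_j = 1 / prod_{k != j} (x_j - x_k); the identity
    # puts a 1 on the diagonal so the j == k factor drops out of the product.
    diffs = from_pts[:, None] - from_pts[None, :]
    w = 1.0 / jnp.prod(diffs + jnp.eye(n), axis=1)
    # d[i, j] = y_i - x_j for target y_i and source x_j.
    d = to_pts[:, None] - from_pts[None, :]
    exact = d == 0.0
    # Avoid dividing by zero; rows with an exact hit are overwritten below.
    terms = w[None, :] / jnp.where(exact, 1.0, d)
    mat = terms / jnp.sum(terms, axis=1, keepdims=True)
    # If y_i coincides with x_j, the interpolant must return the sample at x_j.
    hit_row = jnp.any(exact, axis=1, keepdims=True)
    return jnp.where(hit_row, exact.astype(mat.dtype), mat)


# Usage: a quadratic sampled at 3 points is reproduced exactly (up to rounding).
x = jnp.array([-1.0, 0.0, 1.0])
y = jnp.array([-0.5, 0.25, 1.0])
P = barycentric_matrix_1d_sketch(x, y)
print(jnp.allclose(P @ x**2, y**2, atol=1e-5))  # True
```

Applying the matrix to samples of any polynomial of degree at most n - 1 should reproduce that polynomial at the target points, which makes a convenient sanity check.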
- jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_2D(from_pts_x: Array, from_pts_y: Array, to_pts_x: Array, to_pts_y: Array) → Array
2D Barycentric Lagrange interpolation matrix. A generalization of [1], modeled after the MATLAB code snippet [2].
The grid of source points is specified by from_pts_x and from_pts_y. The resulting matrix has columns ordered to map from samples on this list of points:
    source_X, source_Y = jnp.meshgrid(from_pts_x, from_pts_y, indexing="ij")
    source_pts = jnp.stack((source_X.flatten(), source_Y.flatten()), axis=-1)
Similarly, the rows are ordered according to the grid of target points specified by:
    target_X, target_Y = jnp.meshgrid(to_pts_x, to_pts_y, indexing="ij")
    target_pts = jnp.stack((target_X.flatten(), target_Y.flatten()), axis=-1)
- Parameters:
from_pts_x (jax.Array) – Has shape (n_x,)
from_pts_y (jax.Array) – Has shape (n_y,)
to_pts_x (jax.Array) – Has shape (p_x,)
to_pts_y (jax.Array) – Has shape (p_y,)
- Returns:
Has shape (p_x * p_y, n_x * n_y)
- Return type:
jax.Array
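Because both grids are tensor products flattened with indexing="ij" (so the flat index is i_x * n_y + i_y), one way to assemble this matrix is as the Kronecker product kron(I_x, I_y) of the two 1D interpolation matrices. The sketch below assumes that construction; _bary_1d and bary_2d_sketch are hypothetical names, and this is not necessarily how jaxhps builds the matrix internally.

```python
import jax.numpy as jnp


def _bary_1d(from_pts, to_pts):
    """Hypothetical helper: (p, n) 1D barycentric Lagrange interpolation matrix."""
    n = from_pts.shape[0]
    w = 1.0 / jnp.prod(from_pts[:, None] - from_pts[None, :] + jnp.eye(n), axis=1)
    d = to_pts[:, None] - from_pts[None, :]
    exact = d == 0.0
    terms = w[None, :] / jnp.where(exact, 1.0, d)
    mat = terms / jnp.sum(terms, axis=1, keepdims=True)
    # Rows whose target coincides with a source point become unit vectors.
    return jnp.where(jnp.any(exact, axis=1, keepdims=True), exact.astype(mat.dtype), mat)


def bary_2d_sketch(from_pts_x, from_pts_y, to_pts_x, to_pts_y):
    """Hypothetical sketch: (p_x * p_y, n_x * n_y) matrix for "ij"-ordered grids."""
    # meshgrid(..., indexing="ij") followed by flatten() gives flat index
    # i_x * n_y + i_y, which is the row/column layout of a Kronecker product.
    return jnp.kron(_bary_1d(from_pts_x, to_pts_x), _bary_1d(from_pts_y, to_pts_y))


# Usage: interpolate f(x, y) = x**2 * y + y from a 4 x 3 grid to a 2 x 3 grid.
fx, fy = jnp.linspace(-1.0, 1.0, 4), jnp.linspace(-1.0, 1.0, 3)
tx, ty = jnp.array([-0.3, 0.6]), jnp.array([0.1, -0.8, 0.45])
M = bary_2d_sketch(fx, fy, tx, ty)
SX, SY = jnp.meshgrid(fx, fy, indexing="ij")
vals = M @ (SX**2 * SY + SY).flatten()
print(M.shape)  # (6, 12)
```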
- jaxhps.quadrature.barycentric_lagrange_interpolation_matrix_3D(from_pts_x: Array, from_pts_y: Array, from_pts_z: Array, to_pts_x: Array, to_pts_y: Array, to_pts_z: Array) → Array
3D Barycentric Lagrange interpolation matrix. A generalization of [1].
The grid of source points is specified by from_pts_x, from_pts_y, and from_pts_z. The resulting matrix has columns ordered to map from samples on this list of points:
    source_X, source_Y, source_Z = jnp.meshgrid(from_pts_x, from_pts_y, from_pts_z, indexing="ij")
    source_pts = jnp.stack((source_X.flatten(), source_Y.flatten(), source_Z.flatten()), axis=-1)
Similarly, the rows are ordered according to the grid of target points specified by:
    target_X, target_Y, target_Z = jnp.meshgrid(to_pts_x, to_pts_y, to_pts_z, indexing="ij")
    target_pts = jnp.stack((target_X.flatten(), target_Y.flatten(), target_Z.flatten()), axis=-1)
- Parameters:
from_pts_x (jax.Array) – Has shape (n_x,)
from_pts_y (jax.Array) – Has shape (n_y,)
from_pts_z (jax.Array) – Has shape (n_z,)
to_pts_x (jax.Array) – Has shape (p_x,)
to_pts_y (jax.Array) – Has shape (p_y,)
to_pts_z (jax.Array) – Has shape (p_z,)
- Returns:
Has shape (p_x * p_y * p_z, n_x * n_y * n_z)
- Return type:
jax.Array
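The same tensor-product reasoning extends to three dimensions: with indexing="ij" the flat index is (i_x * n_y + i_y) * n_z + i_z, which matches the layout of kron(I_x, kron(I_y, I_z)). The sketch below assumes that construction; _bary_1d and bary_3d_sketch are hypothetical names, and jaxhps may assemble the matrix differently.

```python
import jax.numpy as jnp


def _bary_1d(from_pts, to_pts):
    """Hypothetical helper: (p, n) 1D barycentric Lagrange interpolation matrix."""
    n = from_pts.shape[0]
    w = 1.0 / jnp.prod(from_pts[:, None] - from_pts[None, :] + jnp.eye(n), axis=1)
    d = to_pts[:, None] - from_pts[None, :]
    exact = d == 0.0
    terms = w[None, :] / jnp.where(exact, 1.0, d)
    mat = terms / jnp.sum(terms, axis=1, keepdims=True)
    # Rows whose target coincides with a source point become unit vectors.
    return jnp.where(jnp.any(exact, axis=1, keepdims=True), exact.astype(mat.dtype), mat)


def bary_3d_sketch(from_pts_x, from_pts_y, from_pts_z, to_pts_x, to_pts_y, to_pts_z):
    """Hypothetical sketch: (p_x * p_y * p_z, n_x * n_y * n_z) matrix for "ij" grids."""
    Ix = _bary_1d(from_pts_x, to_pts_x)
    Iy = _bary_1d(from_pts_y, to_pts_y)
    Iz = _bary_1d(from_pts_z, to_pts_z)
    # x varies slowest and z fastest in the flattened "ij" ordering.
    return jnp.kron(Ix, jnp.kron(Iy, Iz))


# Usage: 3 x 2 x 3 source grid, 2 x 2 x 2 target grid.
fx, fy, fz = jnp.array([-1.0, 0.0, 1.0]), jnp.array([-1.0, 1.0]), jnp.array([-1.0, 0.0, 1.0])
tx, ty, tz = jnp.array([0.2, -0.7]), jnp.array([0.5, -0.25]), jnp.array([0.9, -0.1])
M = bary_3d_sketch(fx, fy, fz, tx, ty, tz)
print(M.shape)  # (8, 18)
```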
- jaxhps.quadrature.chebyshev_points(n: int) → Array
Returns n Chebyshev points over the interval [-1, 1]:
out[i] = cos(pi * (n - 1 - i) / (n - 1)) for i = 0, …, n - 1.
The left end of the interval is returned first.
- Parameters:
n (int) – number of Chebyshev points to return
- Returns:
The sampled points in [-1, 1] and the corresponding angles in [0, pi]
- Return type:
jax.Array
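The formula above evaluates cos at the angles theta_i = pi * (n - 1 - i) / (n - 1), which decrease from pi to 0, so the points increase from -1 to 1. A minimal sketch, assuming n >= 2 and returning only the points; chebyshev_points_sketch is a hypothetical name:

```python
import jax.numpy as jnp


def chebyshev_points_sketch(n: int) -> jnp.ndarray:
    """Hypothetical sketch of out[i] = cos(pi * (n - 1 - i) / (n - 1)); assumes n >= 2."""
    i = jnp.arange(n)
    angles = jnp.pi * (n - 1 - i) / (n - 1)  # decreasing from pi to 0
    return jnp.cos(angles)                    # increasing from -1 to 1


print(chebyshev_points_sketch(5))  # approximately [-1, -0.7071, 0, 0.7071, 1]
```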
- jaxhps.quadrature.chebyshev_weights(n: int, bounds: array) → array
Generates weights for a Chebyshev quadrature rule with n points over the interval [a, b].
Uses the Clenshaw-Curtis quadrature rule, specifically the version used in Chebfun. See [3].
- Parameters:
n (int) – Number of quadrature points
bounds (jnp.array) – Has shape (2,) and contains the interval endpoints [a, b]
- Returns:
Has shape (n,) and contains the quadrature weights
- Return type:
jnp.array
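For checking weights by hand, the sketch below computes Clenshaw-Curtis weights with the explicit cosine-sum formula used in Trefethen's clencurt.m, rather than the FFT-based Chebfun algorithm cited above. Both produce the weights of the interpolatory rule on the same n Chebyshev points, so they should agree up to rounding. Weights on [-1, 1] are scaled by (b - a)/2 for [a, b]; they are symmetric, so the left-to-right ordering of the points does not matter. clenshaw_curtis_weights_sketch is a hypothetical name, and the sketch assumes n >= 2.

```python
import jax.numpy as jnp


def clenshaw_curtis_weights_sketch(n: int, bounds: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch: Clenshaw-Curtis weights for n >= 2 Chebyshev points on [a, b]."""
    N = n - 1
    theta = jnp.pi * jnp.arange(n) / N
    inner = theta[1:-1]  # angles of the interior points
    v = jnp.ones(N - 1)
    if N % 2 == 0:
        end_w = 1.0 / (N**2 - 1)
        for k in range(1, N // 2):
            v = v - 2.0 * jnp.cos(2 * k * inner) / (4 * k**2 - 1)
        v = v - jnp.cos(N * inner) / (N**2 - 1)
    else:
        end_w = 1.0 / N**2
        for k in range(1, (N - 1) // 2 + 1):
            v = v - 2.0 * jnp.cos(2 * k * inner) / (4 * k**2 - 1)
    w = jnp.concatenate([jnp.array([end_w]), 2.0 * v / N, jnp.array([end_w])])
    # The weights above integrate over [-1, 1]; rescale by the Jacobian (b - a) / 2.
    a, b = bounds[0], bounds[1]
    return w * (b - a) / 2.0


# Usage: with 3 points on [-1, 1] this reduces to Simpson's rule.
print(clenshaw_curtis_weights_sketch(3, jnp.array([-1.0, 1.0])))  # approximately [1/3, 4/3, 1/3]
```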
- jaxhps.quadrature.differentiation_matrix_1D(points: Array) → Array
Creates a 1D Chebyshev differentiation matrix as described in Chapter 6 of [4].
Expects Chebyshev points on the interval [-1, 1].
- Parameters:
points (jnp.ndarray) – Has shape (p,)
- Returns:
Has shape (p,p)
- Return type:
jnp.ndarray
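The standard explicit formulas for a Chebyshev differentiation matrix are D_ij = (c_i / c_j) (-1)^(i+j) / (x_i - x_j) for i ≠ j, with c_i = 2 at the two endpoints and 1 otherwise, and diagonal entries chosen so that every row sums to zero. The sketch below implements those formulas in the style of Trefethen's cheb.m; they hold for Chebyshev points in either monotone order. cheb_diff_matrix_sketch is a hypothetical name and may differ from the jaxhps implementation.

```python
import jax.numpy as jnp


def cheb_diff_matrix_sketch(points: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch: (p, p) differentiation matrix on Chebyshev points."""
    p = points.shape[0]
    idx = jnp.arange(p)
    # c_i = 2 at the endpoints and 1 elsewhere, multiplied by the sign (-1)^i.
    c = jnp.where((idx == 0) | (idx == p - 1), 2.0, 1.0)
    c = c * jnp.where(idx % 2 == 0, 1.0, -1.0)
    dx = points[:, None] - points[None, :]
    eye = jnp.eye(p)
    # Off-diagonal: D_ij = (c_i / c_j) / (x_i - x_j); the signs in c supply (-1)^(i+j).
    off = (c[:, None] / c[None, :]) / (dx + eye) * (1.0 - eye)
    # Diagonal: each row sums to zero because the derivative of a constant is 0.
    return off - jnp.diag(jnp.sum(off, axis=1))


# Usage: differentiate x**3 sampled on 6 Chebyshev points (left end first).
n = 6
x = jnp.cos(jnp.pi * (n - 1 - jnp.arange(n)) / (n - 1))
D = cheb_diff_matrix_sketch(x)
print(jnp.allclose(D @ x**3, 3 * x**2, atol=1e-4))  # True
```

Applying D to samples of a polynomial of degree at most p - 1 returns samples of its derivative, up to rounding.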
- jaxhps.quadrature.gauss_points(n: int) → Array
Returns n Gauss-Legendre points over the interval [-1, 1]. This is a wrapper for numpy.polynomial.legendre.leggauss.
- Parameters:
n (int) – Number of points
- Returns:
Has shape (n,)
- Return type:
jax.Array
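numpy.polynomial.legendre.leggauss(n) returns both the Gauss-Legendre nodes and weights on [-1, 1], so a wrapper mainly needs to convert the result to JAX arrays. The sketch below keeps the weights too, to show that an n-point rule integrates polynomials of degree up to 2n - 1 exactly; gauss_points_sketch is a hypothetical name, whereas jaxhps.quadrature.gauss_points itself returns only the points.

```python
import numpy as np
import jax.numpy as jnp


def gauss_points_sketch(n: int):
    """Hypothetical sketch wrapping numpy.polynomial.legendre.leggauss."""
    # leggauss returns (nodes, weights) on [-1, 1], with nodes in ascending order.
    nodes, weights = np.polynomial.legendre.leggauss(n)
    return jnp.asarray(nodes), jnp.asarray(weights)


# Usage: a 3-point rule integrates x**4 over [-1, 1] exactly (result close to 2/5).
nodes, weights = gauss_points_sketch(3)
print(jnp.sum(weights * nodes**4))
```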