Data placement across accelerator devices

By default, JAX places every newly created jax.Array on an accelerator device (such as a GPU) if one is available. For our use case this behavior is undesirable: we want to initialize large arrays of data on the host and move them to the accelerator only when a computational step needs them. If you plan to use jaxhps to compute solutions of PDEs with large numbers of discretization points, the default behavior can become problematic, for instance because accelerator memory is typically much smaller than host memory. For this reason, we suggest overriding JAX's default device by running the following code:

import jax
jax.config.update("jax_default_device", jax.devices("cpu")[0])

With this setting, all new jax.Array objects are created on the CPU by default and must be explicitly moved to the GPU (or other accelerator) before computations can run there. All of the example scripts described in Examples run this line of code after importing the necessary modules.
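As a minimal sketch of this workflow, relying only on the public jax.devices and jax.device_put calls (the variable names cpu, compute_device, x and x_dev are illustrative, not part of jaxhps), the following creates an array on the CPU and then moves it to an accelerator if one is present, falling back to the CPU otherwise:

```python
import jax
import jax.numpy as jnp

# Make the CPU the default placement for newly created arrays.
jax.config.update("jax_default_device", jax.devices("cpu")[0])

cpu = jax.devices("cpu")[0]
# Assumption: the first non-CPU device is the accelerator we want to use.
# If there is none, fall back to the CPU so the sketch still runs.
accelerators = [d for d in jax.devices() if d.platform != "cpu"]
compute_device = accelerators[0] if accelerators else cpu

# Large data is created on the host (CPU) ...
x = jnp.ones((1000, 1000))
# ... and copied to the compute device only when it is needed there.
x_dev = jax.device_put(x, compute_device)
```

Because x was created after the config update, it lives on the CPU; x_dev is a separate copy on the compute device, so the host copy can be freed or kept as needed.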

Most of the functions described in API for solution methods take compute_device and host_device arguments. The compute_device is the device on which the computation is performed; the functions take care of moving the input data there. The host_device is where the data is stored after the computation finishes; the functions move the results to that device before returning.
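The pattern can be illustrated with a small sketch. The function hypothetical_solve below is a hypothetical stand-in, not a jaxhps function, and its norm computation is only a placeholder; it shows the assumed shape of the compute_device/host_device contract: move inputs to compute_device, compute there, and move the result to host_device before returning.

```python
import jax
import jax.numpy as jnp


def hypothetical_solve(data, compute_device, host_device):
    """Hypothetical stand-in for a jaxhps solver (not the library's API)."""
    # Move the inputs to the device where the work will be done.
    data = jax.device_put(data, compute_device)
    # Placeholder computation: row-wise Euclidean norms, run on compute_device.
    result = jnp.linalg.norm(data, axis=-1)
    # Move the result to the storage device before returning it.
    return jax.device_put(result, host_device)


cpu = jax.devices("cpu")[0]
accelerator = jax.devices()[0]  # a GPU if available, otherwise the CPU
out = hypothetical_solve(jnp.ones((4, 3)), compute_device=accelerator, host_device=cpu)
```

Keeping the transfers inside the function means the caller only has to decide where the work runs and where the result should end up, without tracking the location of intermediate arrays.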