# hostsub_gp/interp.py
import jax
import jax.numpy as jnp
from functools import partial
from typing import Callable
from jax._src.typing import ArrayLike, Array
from typing import Optional
from ._utils import msgs
# jax.config.update("jax_enable_x64", True)
###################################################################################################
################################# Interpolation on a regular grid #################################
###################################################################################################
class Interp1D_Grid:
    """
    1D linear interpolation on an increasing grid.

    A thin wrapper around :func:`jax.numpy.interp`.

    Parameters
    ----------
    points : Array
        1D grid coordinates. They are assumed to be increasing, because
        ``jnp.interp`` requires that but does not check it.
    values : Array
        1D values at ``points``.
    method : str, optional
        Interpolation method. Only ``"linear"`` is supported. Any other value
        raises ``ValueError`` rather than being silently ignored.
    """

    _SUPPORTED_METHODS = ("linear",)

    def __init__(self, points: Array, values: Array, method="linear"):
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(
                f"Unsupported method {method!r}; Interp1D_Grid only supports "
                f"{self._SUPPORTED_METHODS}"
            )
        self.method = method
        self.points = jnp.asarray(points)
        self.values = jnp.asarray(values)

    # NOTE(review): ``self`` is a static jit argument, so compiled code is cached
    # per instance. Mutating ``points``/``values`` after the first call is not
    # picked up.
    @partial(jax.jit, static_argnums=(0,))
    def __call__(self, x):
        """
        Interpolate at ``x``.

        Parameters
        ----------
        x : ArrayLike
            Query coordinate(s), scalar or array of any shape.

        Returns
        -------
        Array
            Interpolated values with the same shape as ``x``.
        """
        # jnp.interp is already elementwise over ``x``. Calling it directly,
        # rather than through jax.vmap, also accepts scalar (0-d) queries.
        return jnp.interp(x, self.points, self.values)
class Interp2D_Grid:
    """
    2D interpolation on a rectilinear grid.

    A wrapper around :class:`jax.scipy.interpolate.RegularGridInterpolator`.

    Parameters
    ----------
    points : tuple[Array, Array]
        Grid coordinates along each axis, ``(x, y)``.
    values : Array
        Grid values, shape ``(len(x), len(y))``.
    method : str, optional
        ``"linear"`` (default) or ``"nearest"``. Any other value raises
        ``ValueError``; previously ``__call__`` silently returned ``None``.
    """

    _SUPPORTED_METHODS = ("nearest", "linear")

    def __init__(self, points: tuple[Array, Array], values: Array, method="linear"):
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(
                f"Unsupported method {method!r}; Interp2D_Grid supports "
                f"{self._SUPPORTED_METHODS}"
            )
        self.method = method
        self.points = points
        self.values = jnp.asarray(values)

    def __call__(self, x):
        """
        Evaluate the interpolant at the query points ``x``.

        ``x`` is presumably of shape ``(n_queries, 2)``, matching the other
        2D interpolators in this module. Out-of-grid behaviour is that of
        ``RegularGridInterpolator``'s defaults.
        """
        # Explicit submodule import: ``jax.scipy.interpolate`` is not guaranteed
        # to be reachable as an attribute after a bare ``import jax``.
        from jax.scipy.interpolate import RegularGridInterpolator

        interpolator = RegularGridInterpolator(
            self.points, self.values, method=self.method
        )
        return interpolator(x)
###################################################################################################
################################# Interpolation on an irregular grid ##############################
###################################################################################################
from abc import ABC, abstractmethod
[docs]
class Interp2D_base(ABC):
"""
Abstract base class for 2D interpolation on semi-uniform grids.
"""
def __init__(self, scales: tuple | ArrayLike = (1, 1)):
"""
Initialize base interpolator for semi-uniform grid (nx, ny).
Parameters
----------
scales : tuple or ArrayLike
Scales for each dimension
"""
self.scales = jnp.asarray(scales)
self.points = None
self.values = None
self.shape = None
self.x_grid = None
self.y_grid = None
[docs]
def fit(self, points: ArrayLike, values: ArrayLike) -> None:
"""
Store the training data and grid information.
Parameters
----------
points : ArrayLike
Array of shape (nx, ny, 2) containing the 2D coordinates
values : ArrayLike
Array of shape (nx, ny) containing the values at each point
"""
points = jnp.asarray(points, dtype=jnp.float32)
values = jnp.asarray(values, dtype=jnp.float32)
self.shape = points.shape[:2]
if values.shape != self.shape:
raise ValueError(
f"Shape mismatch between points {(points.shape)} and values {(values.shape)}"
)
self.points = points / self.scales
self.values = jnp.asarray(values, dtype=jnp.float32)
# Separate x and y coordinates for easier access
self.x_grid = self.points[..., 0]
self.y_grid = self.points[..., 1]
# @partial(jax.jit, static_argnums=(0,))
def _find_nearest_cell(self, query_point: Array) -> tuple[Array, Array]:
"""
Find the cell in the semi-uniform grid containing or nearest to the query point
using binary search (searchsorted).
Parameters
----------
query_point : Array
Point to interpolate at, shape (2,)
Returns
-------
Tuple[Array, Array]
i, j: indices of the lower-left corner of the containing cell
"""
qx, qy = query_point
if self.shape is None or self.x_grid is None or self.y_grid is None:
raise ValueError(
"Interpolator not fitted yet. Call fit() with training data."
)
# First search for x position in middle column
mid_row = self.shape[1] // 2
i_mid = jnp.searchsorted(self.x_grid[:, mid_row], qx) - 1
i_mid = jnp.clip(i_mid, 0, self.shape[0] - 2)
# Get y values for this row
y_row = self.y_grid[i_mid, :]
j_index = jnp.searchsorted(y_row, qy) - 1
j_index = jnp.clip(j_index, 0, self.shape[1] - 2)
# Refine x search in the found column
i_index = jnp.searchsorted(self.x_grid[:, j_index], qx) - 1
i_index = jnp.clip(i_index, 0, self.shape[0] - 2)
return i_index, j_index
# @partial(jax.jit, static_argnums=(0, 3, 4))
def _get_cell_points_and_values(
self, i: Array, j: Array, di: int = 4, dj: int = 2
) -> tuple[Array, Array]:
"""
Get the neighboring cells of a grid cell and their values.
Parameters
----------
i, j : int
Indices of the lower-left corner of the cell
di, dj : int, optional
Number of cells to include in each direction, default is 2
Returns
-------
Tuple[Array, Array]
points: Array of shape (4, 2) containing corner coordinates
values: Array of shape (4,) containing corner values
"""
if (
self.values is None
or self.points is None
or self.shape is None
or self.x_grid is None
or self.y_grid is None
):
raise ValueError(
"Interpolator not fitted yet. Call fit() with training data."
)
i_idx = jnp.arange(-di // 2, di // 2)
j_idx = jnp.arange(-dj // 2, dj // 2)
i_idx = jnp.clip(i + i_idx, 0, self.shape[0] - 1)
j_idx = jnp.clip(j + j_idx, 0, self.shape[1] - 1)
# Stacking all di x dj points (with meshgrid)
i_grid, j_grid = jnp.meshgrid(i_idx, j_idx, indexing="ij")
i_grid = i_grid.ravel()
j_grid = j_grid.ravel()
points = jnp.stack(
[self.x_grid[i_grid, j_grid], self.y_grid[i_grid, j_grid]], axis=1
)
values = self.values[i_grid, j_grid]
return points, values
@abstractmethod
def _compute_weights(
self,
cell_points: Array,
query_point: Array,
cell_values: Optional[Array] = None,
) -> Array:
"""
Compute interpolation weights for given cell points and query point.
To be implemented by subclasses.
Parameters
----------
cell_points : Array
Corner points of the cell, shape (4, 2)
query_point : Array
Point to interpolate at, shape (2,)
cell_values : Array, optional
Values at the corner points, shape (4,)
Returns
-------
Array
Weights for interpolation, shape (4,)
"""
raise NotImplementedError
# @partial(jax.jit, static_argnums=(0,))
def _interpolate_single(self, query_point: Array) -> Array:
"""
Interpolate value at a single query point with NaN handling
"""
# Find containing/nearest cell
i, j = self._find_nearest_cell(query_point)
# Get cell points and values
cell_points, cell_values = self._get_cell_points_and_values(i, j)
# Check for NaN values
valid_mask = ~jnp.isnan(cell_values)
n_valid = jnp.sum(valid_mask)
def interpolate():
# Compute weights using subclass implementation
weights = self._compute_weights(cell_points, query_point, cell_values)
# Handle NaN values by redistributing weights
weights = jnp.where(valid_mask, weights, 0.0)
weights = weights / (jnp.sum(weights) + 1e-10)
return jnp.sum(jnp.where(valid_mask, cell_values, 0.0) * weights)
# Return NaN if not enough valid values
return jax.lax.cond(n_valid >= 3, lambda: interpolate(), lambda: jnp.nan)
[docs]
def predict(self, query_points: ArrayLike) -> Array:
"""
Make predictions at new points with NaN handling
Parameters
----------
query_points : ArrayLike
Array of shape (n_queries, 2) containing points to interpolate
Returns
-------
Array
Array of interpolated values at query_points
"""
query_points = jnp.asarray(query_points)
# Check for NaN in query points
valid_queries = ~jnp.any(jnp.isnan(query_points), axis=1)
# Scale query points
scaled_points = query_points / self.scales
# Vectorize the single point interpolation
predictions = jax.vmap(self._interpolate_single)(scaled_points)
# Ensure NaN for invalid query points
return jnp.where(valid_queries, predictions, jnp.nan)
class Interp2D_Linear(Interp2D_base):
    """
    2D inverse-distance-weighted interpolation on a semi-uniform grid.

    Despite the class name, this is not bilinear interpolation. Each
    neighbouring node gets a weight of ``1 / distance`` to the query point,
    and the base class renormalises the weights. A query that lies within
    ``exact_tol`` of a node returns that node's value.
    """

    #: Distances below this threshold count as an exact hit on a grid node.
    exact_tol: float = 1e-10

    def _compute_weights(
        self,
        cell_points: Array,
        query_point: Array,
        cell_values: Optional[Array] = None,
    ) -> Array:
        """
        Return inverse-distance weights for ``cell_points``.

        If any node is within ``exact_tol`` of the query, only those nodes get
        a non-zero weight (1.0). The same threshold is used both to detect the
        exact hit and to select the nodes. Previously the nodes were selected
        with ``== 0``, which gave all-zero weights for tiny non-zero distances.
        """
        distances = jnp.linalg.norm(cell_points - query_point, axis=1)
        exact = distances < self.exact_tol
        # Replace near-zero distances before dividing, so that no inf/NaN is
        # produced even in the branch that jnp.where discards.
        inverse = 1.0 / jnp.where(exact, 1.0, distances)
        return jnp.where(jnp.any(exact), exact.astype(inverse.dtype), inverse)
class Interp2D_RBF(Interp2D_base):
    """
    2D interpolation using radial basis functions on a local neighbourhood.

    For each query point ``q``, an RBF interpolant is fitted to the grid
    nodes ``p_a`` gathered by the base class. Let
    ``K[a, b] = phi(|p_a - p_b|)`` and ``k[a] = phi(|p_a - q|)``. The weights
    are ``w = (K + smoothing * I)^{-1} k``, so that ``sum(w * v)`` is the RBF
    interpolant at ``q``. The base class additionally renormalises by
    ``sum(w)``.

    Two kinds of node are excluded from the linear system and get zero
    weight:

    - nodes with NaN values;
    - exact repeats of an earlier node. The base class's index clipping
      produces these near grid edges, and they would make ``K`` singular.
    """

    def __init__(
        self,
        kernel: str = "gaussian",
        epsilon: float = 1.0,
        scales: tuple | ArrayLike = (1, 1),
        smoothing: float = 0.0,
    ):
        """
        Initialize RBF interpolator.

        Parameters
        ----------
        kernel : str
            Type of RBF kernel ('gaussian', 'multiquadric', or
            'inverse_multiquadric').
        epsilon : float
            Shape parameter for the RBF kernel.
        scales : tuple or ArrayLike
            Scales for each dimension.
        smoothing : float, optional
            Value added to the diagonal of the kernel matrix. The default 0.0
            gives exact interpolation; a small positive value helps with
            ill-conditioned kernels.

        Raises
        ------
        ValueError
            If ``kernel`` is not one of the supported names.
        """
        super().__init__(scales)
        self.epsilon = epsilon
        self.smoothing = smoothing
        self.kernel = self._get_kernel(kernel)

    def _get_kernel(self, kernel_name: str) -> Callable:
        """Return the RBF kernel function ``phi(r)`` for ``kernel_name``."""

        def gaussian(r):
            return jnp.where(jnp.isfinite(r), jnp.exp(-((self.epsilon * r) ** 2)), 0)

        def multiquadric(r):
            return jnp.where(jnp.isfinite(r), jnp.sqrt(1 + (self.epsilon * r) ** 2), 0)

        def inverse_multiquadric(r):
            return jnp.where(
                jnp.isfinite(r), 1 / jnp.sqrt(1 + (self.epsilon * r) ** 2), 0
            )

        kernels = {
            "gaussian": gaussian,
            "multiquadric": multiquadric,
            "inverse_multiquadric": inverse_multiquadric,
        }
        if kernel_name not in kernels:
            raise ValueError(
                f"Unknown RBF kernel {kernel_name!r}; expected one of {sorted(kernels)}"
            )
        return kernels[kernel_name]

    def _compute_weights(
        self,
        cell_points: Array,
        query_points: Array,
        cell_values: Optional[Array] = None,
    ) -> Array:
        """
        Solve the local RBF system for the weights of ``cell_points``.

        Parameters
        ----------
        cell_points : Array
            Neighbourhood node coordinates, shape (n, 2).
        query_points : Array
            The single query point, shape (2,).
        cell_values : Array, optional
            Node values, shape (n,). NaN entries are excluded from the solve.

        Returns
        -------
        Array
            Weights of shape (n,). Excluded nodes get exactly 0.
        """
        n = cell_points.shape[0]
        if cell_values is None:
            active = jnp.ones(n, dtype=bool)
        else:
            active = ~jnp.isnan(cell_values)

        # Drop exact repeats of an earlier node: duplicate rows make K singular.
        same = jnp.all(cell_points[:, None, :] == cell_points[None, :, :], axis=-1)
        repeated = jnp.any(jnp.tril(same, k=-1), axis=1)
        active = active & ~repeated

        pair_dist = jnp.linalg.norm(
            cell_points[:, None, :] - cell_points[None, :, :], axis=-1
        )
        kmat = self.kernel(pair_dist)
        kvec = self.kernel(jnp.linalg.norm(cell_points - query_points, axis=1))

        # Inactive nodes become decoupled identity rows/columns with a zero
        # right-hand side, so their weight is exactly 0 and they do not
        # influence the active sub-system.
        eye = jnp.eye(n, dtype=kmat.dtype)
        both_active = active[:, None] & active[None, :]
        kmat = jnp.where(both_active, kmat + self.smoothing * eye, eye)
        kvec = jnp.where(active, kvec, 0.0)
        return jnp.linalg.solve(kmat, kvec)
class Interp2D_Scipy(Interp2D_base):
    """
    2D interpolation using :func:`scipy.interpolate.griddata`.

    The grid structure is ignored: all fitted nodes are flattened into a
    scattered point cloud and passed to ``griddata``.

    NaN handling:

    - Nodes with a non-finite coordinate or value are dropped first.
    - Query points containing NaN yield NaN, consistent with
      :meth:`Interp2D_base.predict`.
    """

    def __init__(
        self,
        method: str = "linear",
        scales: tuple | ArrayLike = (1, 1),
    ):
        """
        Initialize Scipy interpolator.

        Parameters
        ----------
        method : str
            ``griddata`` method: 'linear', 'nearest' or 'cubic'.
        scales : tuple or ArrayLike
            Scales for each dimension.
        """
        super().__init__(scales)
        self.method = method

    def _compute_weights(
        self,
        cell_points: Array,
        query_point: Array,
        cell_values: Optional[Array] = None,
    ) -> Array:
        """
        Unused placeholder that satisfies the abstract interface.

        :meth:`predict` is overridden and calls ``griddata`` directly.
        """
        return jnp.ones(cell_points.shape[0])

    def predict(self, query_points: ArrayLike) -> Array:
        """
        Make predictions at new points with scipy's griddata.

        Parameters
        ----------
        query_points : ArrayLike
            Array of shape (n_queries, 2) containing points to interpolate. A
            single point of shape (2,) is treated as (1, 2).

        Returns
        -------
        Array
            Interpolated values, shape (n_queries,).

        Raises
        ------
        ValueError
            If the interpolator has not been fitted.
        """
        import numpy as np
        from scipy.interpolate import griddata

        if self.points is None or self.values is None:
            raise ValueError(
                "Interpolator not fitted yet. Call fit() with training data."
            )

        query = np.atleast_2d(np.asarray(query_points, dtype=float))
        query = query / np.asarray(self.scales)
        # Flatten the grid nodes into a scattered point cloud
        flat_points = np.asarray(self.points).reshape(-1, 2)
        flat_values = np.asarray(self.values).ravel()

        # Drop nodes that cannot contribute meaningfully (NaN/inf coordinates
        # or values), and skip queries containing NaN.
        keep = np.isfinite(flat_points).all(axis=1) & np.isfinite(flat_values)
        valid_queries = ~np.isnan(query).any(axis=1)

        result = np.full(query.shape[0], np.nan)
        if valid_queries.any() and keep.any():
            result[valid_queries] = griddata(
                flat_points[keep],
                flat_values[keep],
                query[valid_queries],
                method=self.method,
            )
        return jnp.asarray(result)