# hostsub_gp/gp.py
__all__ = ["GP"]
import jax
import jax.numpy as jnp
# jax.config.update("jax_enable_x64", True)
import jaxopt
from functools import partial
from tinygp import GaussianProcess, kernels, transforms
from tinygp.helpers import JAXArray
from tinygp.kernels.distance import L2Distance
from jax._src.typing import ArrayLike, Array
from typing import Optional
from ._utils._par import (
# _transform_unbound_to_bound,
# _transform_bound_to_unbound,
init_params,
init_params_limit,
check_params,
print_params,
)
from ._utils import msgs
import warnings
[docs]
class GP:
    """
    Gaussian Process.

    A wrapper around tinygp.GaussianProcess that stores the training data,
    optionally fits the kernel hyperparameters with jaxopt, and exposes
    ``predict`` and ``log_probability``.
    """

    def __init__(
        self,
        kernel_type: str,
        X: ArrayLike,
        y: Optional[ArrayLike] = None,
        yerr: Optional[ArrayLike | float] = None,
        params: Optional[dict] = None,
        params_init: Optional[dict] = None,
        params_limit: Optional[dict] = None,
        optimization: bool = False,
        **kwargs,
    ):
        """Initialize the Gaussian Process.

        Parameters
        ----------
        kernel_type
            Kernel selector forwarded to ``_build_gp``
            ('HostProfile', '1D' or '2D').
        X
            Input coordinates; the size of the last axis selects 1D vs 2D.
        y
            Observations. Required when ``optimization`` is True; may be
            omitted otherwise (``predict`` is then unavailable).
        yerr
            Uncertainties on ``y``: None (zeros), a scalar (broadcast to every
            point) or an array.
        params
            Hyperparameters used as-is when ``optimization`` is False.
        params_init
            Starting point of the optimizer (passed through ``init_params``).
        params_limit
            Bounds handed to ``init_params_limit`` for the optimizer.
        optimization
            If True, ``params`` is ignored and the hyperparameters are fitted.
        **kwargs
            Forwarded to the kernel builder (e.g. ``emission_lines``).

        Raises
        ------
        ValueError
            If ``optimization`` is requested without ``y``, or if
            ``init_params`` rejects ``params_init``.
        """
        # Initialize the input arrays
        self.X = jnp.asarray(X)
        if y is None:
            if optimization:
                raise ValueError("Optimization: y must be provided")
            # No observations: still define y/yerr so the GP can be built
            # (previously these attributes were left unset and the build
            # below failed with an AttributeError).
            self.y = None
            template = jnp.zeros(self.X.shape[0])
        else:
            self.y = jnp.asarray(y)
            template = self.y

        if yerr is None:
            self.yerr = jnp.zeros_like(template)
        elif isinstance(yerr, (int, float)):
            self.yerr = jnp.ones_like(template) * yerr
        else:
            self.yerr = jnp.asarray(yerr)

        self.kernel_type = kernel_type
        if params is None:
            params = {}
        if params_init is None:
            params_init = {}

        # Initialize the parameters
        self.params_limit = params_limit if params_limit is not None else {}
        if optimization:
            try:
                self.params_init = init_params(params_init)
            except Exception as e:
                # Same message as before; the original error is kept as cause
                raise ValueError("Optimization: " + str(e)) from e
            assert isinstance(self.params_init, dict), (
                "params_init must be a dictionary"
            )
            self.params = self._optimize(self.X, self.y, self.yerr)
            print_params(self.params)
        else:
            self.params = params

        # Build the GP
        # NOTE(review): _optimize drops non-finite y values, but the GP here is
        # built on every point of X - confirm this is intended.
        self.gp = _build_gp(self.params, self.X, self.yerr)(
            kernel_type=self.kernel_type, **kwargs
        )

    def _optimize(self, X: Array, y: Array, yerr: Array) -> dict:
        """Optimize the hyperparameters of the Gaussian Process with jaxopt.ScipyBoundedMinimize.

        Only the points where ``y`` is finite are used. The objective is
        ``_neg_log_prob`` and the bounds come from
        ``init_params_limit(self.params_init, self.params_limit)``.

        Returns
        -------
        dict
            The optimized parameters (``soln.params``).

        Raises
        ------
        ValueError
            If the objective is not finite at ``self.params_init``.
        """
        valid = jnp.isfinite(y)
        X_valid, y_valid, yerr_valid = X[valid], y[valid], yerr[valid]

        # Check if the initial parameters are valid
        neg_log_prob_init = _neg_log_prob(
            self.params_init,
            X=X_valid,
            y=y_valid,
            yerr=yerr_valid,
            kernel_type=self.kernel_type,
        )
        if not jnp.isfinite(neg_log_prob_init):
            msgs.error("Initial log-probability is infinite.")
            msgs.info("Initial parameters:")
            print_params(self.params_init)
            msgs.info("Parameter limits:")
            print_params(self.params_limit)
            raise ValueError("Invalid initial parameters: please check the limits.")

        objective = partial(
            _neg_log_prob,
            X=X_valid,
            y=y_valid,
            yerr=yerr_valid,
            kernel_type=self.kernel_type,
        )
        solver = jaxopt.ScipyBoundedMinimize(fun=objective, method="L-BFGS-B")
        soln = solver.run(
            self.params_init,
            bounds=init_params_limit(self.params_init, self.params_limit),
        )
        msgs.info(f"Initial negative log-probability: {neg_log_prob_init:.1f}")
        msgs.info(f"Final negative log-probability: {soln.state.fun_val:.1f}")
        return soln.params

    def predict(
        self, X_test: Array, return_var: bool = False
    ) -> Array | tuple[Array, Array]:
        """Predict the mean and variance of the Gaussian Process at the input points.

        Parameters
        ----------
        X_test
            Coordinates at which to predict.
        return_var
            Forwarded to ``tinygp.GaussianProcess.predict``.

        Raises
        ------
        ValueError
            If the GP was created without observations ``y``.
        """
        if self.y is None:
            raise ValueError("Prediction: y must be provided")
        X_test = jnp.asarray(X_test)
        # The 1D GP uses the quasiseparable kernel to speed up the computation
        # which requires the input to be a 1D array
        if X_test.shape[-1] == 1:
            X_test = jnp.asarray(X_test.ravel())
        return self.gp.predict(self.y, X_test, return_var=return_var)

    def log_probability(self, y: ArrayLike) -> JAXArray:
        """Log-probability of ``y`` under the Gaussian Process."""
        return self.gp.log_probability(jnp.asarray(y))
@partial(jax.jit, static_argnames=("kernel_type",))
def _neg_log_prob(
    params: dict, X: Array, y: Array, yerr: Array, kernel_type: str
) -> JAXArray:
    """Return minus the GP log-probability of ``y`` for the given parameters.

    A GP is built from ``params``, ``X`` and ``yerr`` with the requested
    ``kernel_type`` (a static jit argument), and the sign of its
    log-probability is flipped so the value can be minimized.
    """
    builder = _build_gp(params, X, yerr)
    log_prob = builder(kernel_type=kernel_type).log_probability(y)
    return -log_prob
class _build_gp:
    """Build the Gaussian Process.

    Two-step factory: ``_build_gp(params, X, yerr)`` validates and stores the
    hyperparameters and inputs; calling the instance with a ``kernel_type``
    returns the ``tinygp.GaussianProcess``.
    """

    def __init__(self, params: dict, X: Array, yerr: Array):
        """Validate ``params`` and store the hyperparameters and inputs.

        Parameters
        ----------
        params
            Hyperparameter dictionary. The keys read here are ``log_amp``,
            ``log_scale``, ``mean``, ``log_amp_line`` and ``scale_line``.
        X
            Input coordinates; the size of the last axis must be 1 or 2.
        yerr
            Uncertainties; squared and used as the GP diagonal.

        Raises
        ------
        ValueError
            If ``check_params`` rejects ``params`` or the last axis of ``X``
            has a size other than 1 or 2.
        """
        # Check if necessary parameters are provided
        try:
            check_params(params)
        except ValueError as e:
            # Same message as before; the original error is kept as cause
            raise ValueError("Building GP: " + str(e)) from e

        # Initialize the parameters
        self.log_amp = params.get("log_amp")
        self.log_scale = params.get("log_scale")
        self.mean = params.get("mean")
        # For EmissionLine kernel only
        self.log_amp_line = params.get("log_amp_line")
        self.scale_line = params.get("scale_line")

        self.ndim = X.shape[-1]
        if self.ndim == 1:
            # The 1D (quasiseparable) kernels work on a flat coordinate array
            self.X = jnp.asarray(X.ravel())
        elif self.ndim == 2:
            self.X = jnp.asarray(X)
        else:
            raise ValueError(
                "Invalid number of dimensions: supported values are 1 or 2"
            )
        self.yerr = jnp.asarray(yerr)

    def __call__(self, kernel_type: str, **kwargs) -> GaussianProcess:
        """Return the GaussianProcess for ``kernel_type``.

        ``kwargs`` are forwarded to ``_build_kernel``.
        """
        kernel = self._build_kernel(kernel_type, **kwargs)
        return GaussianProcess(
            kernel=kernel, X=self.X, diag=self.yerr**2, mean=self.mean
        )

    def _build_kernel(self, kernel_type: str, **kwargs) -> kernels.Kernel:
        """
        Build the kernel.

        - HostProfile: 2D GP for the host profile - mean function for the consequent GP models
        - 1D: 1D GP for the 1D spectrum
        - 2D: 2D GP for the 2D spectrum (requires the ``emission_lines`` keyword)

        Raises
        ------
        ValueError
            If ``kernel_type`` is unknown, the dimensionality of X does not
            match the kernel, or a parameter needed by the '2D' kernel is
            missing.
        """
        assert self.log_amp is not None, "log_amp must be provided"
        assert self.log_scale is not None, "log_scale must be provided"
        # Amplitudes and scales are stored as log10 values
        amp = 10**self.log_amp
        scale = 10**self.log_scale

        # Model the host profile prior - 2D GP
        # (ExpSquared + Matern32) x (ExpSquared + Matern32)
        if kernel_type == "HostProfile":
            if self.ndim != 2:
                raise ValueError("HostProfile kernel: X must be a 2D array")
            return _build_gp._build_2D_composite_kernel(amp, scale)

        # Model the 1D spectrum - 1D GP
        # Matern52 + Matern32 (quasiseparable)
        if kernel_type == "1D":
            if self.ndim != 1:
                raise ValueError("1D kernel: X must be a 1D array")
            return _build_gp._build_1D_composite_kernel(amp, scale)

        # Model the 2D spectrum - 2D GP
        # Matern32 * EmissionLine
        if kernel_type == "2D":
            if self.ndim != 2:
                raise ValueError("2D kernel: X must be a 2D array")
            if self.log_amp_line is None or self.scale_line is None:
                raise ValueError(
                    "EmissionLine kernel requires 'log_amp_line' and 'scale_line' parameters"
                )
            emission_lines = kwargs.get("emission_lines")
            if emission_lines is None:
                # The kernel iterates over the lines, so None cannot be used;
                # fail explicitly (also when asserts are disabled with -O).
                raise ValueError("emission_lines must be provided")
            base_kernel = _build_gp._build_2D_single_kernel(amp, scale)
            emission_line_kernel = EmissionLineKernel(
                amp_line=10**self.log_amp_line,
                scale_line=self.scale_line,
                emission_lines=emission_lines,
            )
            return base_kernel * emission_line_kernel

        # Invalid kernel type
        raise ValueError(
            "Invalid kernel type: supported types are 'HostProfile', '1D', '2D'"
        )

    @staticmethod
    def _build_1D_composite_kernel(amp: Array, scale: Array) -> kernels.Kernel:
        """
        Build a composite kernel for 1D data.

        Kernel (quasiseparable) = Matern52 + Matern32

        Raises
        ------
        ValueError
            If ``amp`` does not have shape (2,).
        """
        if amp.shape != (2,):
            raise ValueError(f"Invalid amplitude shape {amp.shape}: expected (2,)")
        # kernel1 : Matern52 - long-term variations (continuum)
        kernel_long = amp[0] * kernels.quasisep.Matern52(scale=scale[0])
        # kernel2 : Matern32 - short-term variations (sky lines, emission lines)
        kernel_short = amp[1] * kernels.quasisep.Matern32(scale=scale[1])
        return kernel_long + kernel_short

    @staticmethod
    def _build_2D_single_kernel(amp: Array, scale: Array) -> kernels.Kernel:
        """
        Build a single kernel for 2D grid data.

        Kernel = Matern32

        Raises
        ------
        ValueError
            If ``amp`` is not a scalar (shape ()).
        """
        if amp.shape != ():
            raise ValueError(f"Invalid amplitude shape {amp.shape}: expected ()")
        # Use transforms.Linear to handle anisotropic kernels
        return amp * transforms.Linear(
            1 / scale, kernel=kernels.Matern32(distance=L2Distance())
        )

    @staticmethod
    def _build_2D_composite_kernel(amp: Array, scale: Array) -> kernels.Kernel:
        """
        Build a composite kernel for 2D grid data.

        Kernel = (ExpSquared + Matern32) x (ExpSquared + Matern32)

        The first factor is evaluated on axis 0 of the inputs and the second
        on axis 1.

        Raises
        ------
        ValueError
            If ``amp`` does not have shape (2, 2).
        """
        # TODO: only 3 free parameters in amp
        if amp.shape != (2, 2):
            raise ValueError(f"Invalid amplitude shape {amp.shape}: expected (2, 2)")

        # Spatially varying parameters - evaluated only on axis 0
        # kernel1 : ExpSquared - long-term variations (continuum)
        kernel_spat_exp = amp[0, 0] * kernels.ExpSquared(scale=scale[0, 0])
        # kernel2 : Matern32 - short-term variations (sky lines, emission lines)
        kernel_spat_matern = amp[0, 1] * kernels.Matern32(scale=scale[0, 1])
        kernel_spat = OneDKernel(kernel=kernel_spat_exp + kernel_spat_matern, axis=0)

        # Spectrally varying parameters - evaluated only on axis 1
        # kernel1 : ExpSquared - long-term variations (continuum)
        kernel_spec_exp = amp[1, 0] * kernels.ExpSquared(scale=scale[1, 0])
        # kernel2 : Matern32 - short-term variations
        kernel_spec_matern = amp[1, 1] * kernels.Matern32(scale=scale[1, 1])
        kernel_spec = OneDKernel(kernel=kernel_spec_exp + kernel_spec_matern, axis=1)

        return kernel_spat * kernel_spec
class OneDKernel(kernels.Kernel):
    """Evaluate a wrapped kernel on a single coordinate axis of the inputs."""

    # Kernel applied to the selected coordinate
    kernel: kernels.Kernel
    # Index into the last dimension of X1/X2 that is passed to ``kernel``
    axis: int

    def evaluate(self, X1: Array, X2: Array) -> Array:
        """Return ``kernel.evaluate`` on coordinate ``axis`` of X1 and X2."""
        coord1 = X1[..., self.axis]
        coord2 = X2[..., self.axis]
        return self.kernel.evaluate(coord1, coord2)
class EmissionLineKernel(kernels.Kernel):
    """A kernel to handle discontinuities at narrow emission lines in the spectroscopic data."""

    # Covariance boost applied when both points are close to a line
    amp_line: float
    # Width passed to the proximity window around each line
    scale_line: float
    # Line positions, compared against the last coordinate of the inputs
    emission_lines: list | Array

    def evaluate(self, X1: Array, X2: Array) -> Array:
        """Evaluate the kernel.

        Starting from 1, one multiplicative factor is applied per emission
        line. Each factor lowers the covariance when exactly one of the two
        points is close to the line and raises it (by ``amp_line``) when
        both are.
        """
        # Only the last coordinate is compared with the line positions
        spec1 = X1[..., -1]
        spec2 = X2[..., -1]

        covariance = jnp.array(1.0)
        for center in self.emission_lines:
            # Proximity of each point to this line
            near1 = _hyperbolic_tangent(spec1, center, self.scale_line)
            near2 = _hyperbolic_tangent(spec2, center, self.scale_line)
            # Both points close to the line
            both_near = near1 * near2
            # Exactly one point close to the line
            one_near = near1 * (1 - near2) + (1 - near1) * near2
            factor = 1.0 - one_near + self.amp_line * both_near
            covariance = covariance * factor
        return covariance
@jax.jit
def _sigmoid(x: Array, mu: Array, width: Array) -> Array:
    """Smooth window around ``mu``: ~1 for ``|x - mu| < width``, ~0 outside.

    The value is 0.5 at ``|x - mu| == width``; the transition steepness is
    fixed by the 0.1 factor.

    Computed as ``sigmoid(-z)``, which equals ``1 / (1 + exp(z))`` but does not
    overflow for large ``z``, so gradients stay finite far from ``mu``.
    """
    z = (jnp.abs(x - mu) / width - 1.0) / 0.1
    return jax.nn.sigmoid(-z)
@jax.jit
def _tophat(x: Array, mu: Array, width: Array) -> Array:
    """Hard window: 1.0 where ``|x - mu| < width`` (strict), 0.0 elsewhere."""
    distance = jnp.abs(x - mu)
    inside = distance < width
    return jnp.where(inside, 1.0, 0.0)
@jax.jit
def _gaussian(x: Array, mu: Array, width: Array) -> Array:
    """Unnormalized Gaussian window centred on ``mu`` with standard deviation ``width``."""
    offset = x - mu
    exponent = -0.5 * offset**2 / width**2
    return jnp.exp(exponent)
@jax.jit
def _hyperbolic_tangent(x: Array, mu: Array, width: Array) -> Array:
    """Smooth window built from two tanh edges at ``mu - width / 2`` and ``mu + width / 2``.

    Half the difference between a rising tanh at the lower edge and a rising
    tanh at the upper edge; the result peaks at ``x == mu`` and decays to 0
    far from it.
    """
    half_width = width / 2
    lower_edge = mu - half_width
    upper_edge = mu + half_width
    rising = jnp.tanh(2 * (x - lower_edge) / width)
    falling = jnp.tanh(2 * (x - upper_edge) / width)
    return 0.5 * (rising - falling)