Source code for hostsub_gp.spec_wrapper

# hostsub_gp/spec_wrapper.py
# import numpy as np

import jax
import jax.numpy as jnp

# jax.config.update("jax_enable_x64", True)

from functools import partial
from ._utils import msgs

from jax._src.typing import ArrayLike, Array
from typing import Optional


class SpecWrapper:
    """
    A wrapper for the 1D and 2D spectra.

    Attributes
    ----------
    spat : Array or None
        Spatial axis (``None`` for a 1D spectrum).
    spec : Array
        Spectral axis.
    spec_img, spat_img : Array
        Coordinate images built with ``jnp.meshgrid`` (2D input only;
        for 1D input ``spec_img`` is ``spec`` and ``spat_img`` is not set).
    X : Array
        Stacked coordinates, one row per pixel.
    shape : tuple
        Shape of ``spec_img``: ``(n_spat, n_spec)`` in 2D, ``(n_spec,)`` in 1D.
    Y, Yerr : Array or None
        Values and errors with the shape ``shape``. Pixels whose error is
        not finite are set to NaN in both arrays.
    y, yerr : Array or None
        Flattened copies of ``Y`` and ``Yerr``.
    """

    def __init__(
        self,
        points: ArrayLike | tuple[ArrayLike, ArrayLike],
        values: Optional[ArrayLike] = None,
        values_err: Optional[ArrayLike] = None,
    ):
        """
        Initialize the SpecWrapper object.

        Parameters
        ----------
        points : ArrayLike | tuple[ArrayLike, ArrayLike]
            The coordinates of the spectrum: either the 1D spectral axis, or
            a ``(spatial axis, spectral axis)`` tuple for a 2D spectrum.
        values : ArrayLike, optional
            The values of the spectrum. If omitted, only the coordinates are
            loaded and ``Y``, ``Yerr``, ``y``, ``yerr`` are ``None``.
        values_err : ArrayLike, optional
            The errors of the values. Zeros are assumed when omitted.

        Raises
        ------
        ValueError
            If the coordinates are not 1D, or the values/errors do not match
            the shape implied by the coordinates.
        """
        # Loading the coordinates
        if isinstance(points, tuple):
            # Input = spatial and spectral axes of the 2D spectrum
            if len(points) != 2:
                raise ValueError("Invalid shape of the input coordinates.")
            self.spat = jnp.asarray(points[0])
            self.spec = jnp.asarray(points[1])
            if self.spat.ndim != 1 or self.spec.ndim != 1:
                raise ValueError("Invalid shape of the input coordinates.")
            self.spec_img, self.spat_img = jnp.meshgrid(self.spec, self.spat)
            self.X = jnp.stack(
                [self.spat_img.ravel(), self.spec_img.ravel()], axis=-1
            )
        else:
            # Input = spectral axis of the 1D spectrum.
            # Convert first so that plain sequences are accepted as well.
            spec = jnp.asarray(points)
            if spec.ndim != 1:
                raise ValueError("Invalid shape of the input coordinates.")
            self.spat = None  # 1D spectrum, no spatial axis
            self.spec = self.spec_img = spec
            self.X = self.spec[:, None]
        self.shape = self.spec_img.shape

        # Loading the values and errors
        if values is None:
            # No values are provided, only coordinates are loaded
            self.Y = self.Yerr = None
            self.y = self.yerr = None
            return

        Y = jnp.asarray(values)
        if Y.shape != self.shape:
            raise ValueError("Invalid shape of the input values.")
        if values_err is None:
            Yerr = jnp.zeros_like(Y)
        else:
            Yerr = jnp.asarray(values_err)
            if Yerr.shape != Y.shape:
                raise ValueError("Values and errors shape mismatch.")

        # A pixel without a finite error is treated as missing.
        has_error = jnp.isfinite(Yerr)
        self.Y = jnp.where(has_error, Y, jnp.nan)
        self.Yerr = jnp.where(has_error, Yerr, jnp.nan)

        # Flatten the values and errors for GP
        self.y = self.Y.ravel()
        self.yerr = self.Yerr.ravel()

    def _points(self):
        """Return the coordinates in the form ``__init__`` expects."""
        if self.spat is None:
            return self.spec
        return (self.spat, self.spec)

    @msgs.timer
    def sigma_clip(
        self,
        sigma: Optional[float] = None,
        clip_cr: bool = False,
        batch_idx: Optional[list | tuple[list, list]] = None,
    ) -> "SpecWrapper":
        """
        Sigma clipping for the spectrum.

        Each batch of pixels is clipped against its own median and
        MAD-based scatter (see ``_clip``); clipped pixels become NaN in both
        the values and the errors.

        Parameters
        ----------
        sigma : float, optional
            Sigma clipping threshold. Default is None, in which case no
            clipping is done and ``self`` is returned.
        clip_cr : bool, optional
            Whether to clip cosmic rays only (i.e., positive outliers).
            Default is False.
        batch_idx : list | tuple[list, list], optional
            Batch indices for sigma clipping. ``batch_idx[0]`` is a list of
            index groups along the first axis; for 2D spectra
            ``batch_idx[1]`` is a list of index groups along the spectral
            axis and every (spatial group, spectral group) pair forms one
            batch. Default is None: the whole spectrum in 1D, or one batch
            per wavelength (all spatial pixels) in 2D.

        Returns
        -------
        SpecWrapper
            The clipped spectrum.

        Raises
        ------
        ValueError
            If the spectrum holds no values, or a 2D spectrum is given
            fewer than two lists of index groups.
        """
        if sigma is None:
            return self
        if self.Y is None:
            raise ValueError("sigma_clip requires non-empty spectra.")

        Y_target = self.Y
        Yerr_target = self.Yerr
        masked_init = ~jnp.isfinite(self.Y)
        is_1d = self.Y.ndim == 1

        if batch_idx is None:
            if is_1d:
                # Statistics over the entire spectrum
                batch_idx = ([jnp.arange(self.shape[0])],)
            else:
                # Statistics at each wavelength (for all spatial pixels)
                batch_idx = (
                    [jnp.arange(self.shape[0])],
                    [jnp.atleast_1d(i) for i in jnp.arange(self.shape[1])],
                )
        elif not is_1d and len(batch_idx) < 2:
            raise ValueError(
                "batch_idx must hold spatial and spectral index groups "
                "for a 2D spectrum."
            )

        # Normalise every group to a 1D index array so that scalars and
        # plain Python lists are handled by the same code path.
        first_groups = [jnp.atleast_1d(jnp.asarray(g)) for g in batch_idx[0]]

        if is_1d:
            for spec_idx in first_groups:
                Y_clipped, Yerr_clipped = _clip(
                    self.Y[spec_idx],
                    self.Yerr[spec_idx],
                    sigma=sigma,
                    clip_cr=clip_cr,
                )
                Y_target = Y_target.at[spec_idx].set(Y_clipped)
                Yerr_target = Yerr_target.at[spec_idx].set(Yerr_clipped)
        else:
            spec_groups = [jnp.atleast_1d(jnp.asarray(g)) for g in batch_idx[1]]
            for spat_idx in first_groups:
                for spec_idx in spec_groups:
                    window = jnp.ix_(spat_idx, spec_idx)
                    window_shape = (spat_idx.size, spec_idx.size)
                    Y_clipped, Yerr_clipped = _clip(
                        self.Y[window].ravel(),
                        self.Yerr[window].ravel(),
                        sigma=sigma,
                        clip_cr=clip_cr,
                    )
                    Y_target = Y_target.at[window].set(
                        Y_clipped.reshape(window_shape)
                    )
                    Yerr_target = Yerr_target.at[window].set(
                        Yerr_clipped.reshape(window_shape)
                    )

        masked_final = ~jnp.isfinite(Y_target)
        msgs.info(f"Sigma clipped {masked_final.sum() - masked_init.sum()} pixels")
        # 1D spectra have no spatial axis, so the points must not be a tuple.
        return SpecWrapper(
            points=self._points(), values=Y_target, values_err=Yerr_target
        )

    @msgs.timer
    def fill_nan(self) -> "SpecWrapper":
        """
        Fill the non-finite values in the spectrum by linear interpolation.

        2D spectra are interpolated in pixel-index space with
        ``scipy.interpolate.griddata``; 1D spectra with ``jnp.interp``.
        The errors are interpolated in the same way as the values.
        Pixels that cannot be interpolated (outside the range covered by
        valid pixels) stay NaN.

        Returns
        -------
        SpecWrapper
            The filled spectrum, or ``self`` when every pixel is valid or
            no pixel is valid.

        Raises
        ------
        ValueError
            If the spectrum holds no values.
        """
        # The emptiness check has to come before any array operation on Y.
        if self.Y is None:
            raise ValueError("Filling NaN requires non-empty spectra.")

        valid = jnp.isfinite(self.Y)
        n_valid = int(jnp.sum(valid))
        # Nothing to fill, or nothing to interpolate from.
        if n_valid == valid.size or n_valid == 0:
            return self

        if self.Y.ndim == 1:
            pix = jnp.arange(self.shape[0], dtype=self.Y.dtype)
            # left/right = NaN: do not extrapolate beyond the valid pixels,
            # mirroring what linear griddata does in 2D.
            Y_filled = jnp.interp(
                pix, pix[valid], self.Y[valid], left=jnp.nan, right=jnp.nan
            )
            Y_err_filled = jnp.interp(
                pix, pix[valid], self.Yerr[valid], left=jnp.nan, right=jnp.nan
            )
        else:
            from scipy.interpolate import griddata

            x, y = jnp.indices(self.shape)
            known = (x[valid], y[valid])
            Y_filled = griddata(known, self.Y[valid], (x, y), method="linear")
            Y_err_filled = griddata(
                known, self.Yerr[valid], (x, y), method="linear"
            )

        filled_count = valid.size - n_valid
        msgs.info(f"Filled {filled_count} NaN pixels")
        return SpecWrapper(
            points=self._points(), values=Y_filled, values_err=Y_err_filled
        )

    def marginalize(
        self,
        margin_type: str = "mean",
        weights: str | Optional[ArrayLike] = None,
        sigma_clip: float = 5.0,
        nan_threshold: float = 0.1,
    ) -> "SpecWrapper":
        """
        Marginalize the 2D spectrum along the spatial axis to obtain the 1D spectrum.

        Parameters
        ----------
        margin_type : str, optional
            Type of the marginalization: mean or sum. Default is mean.
            The sum is the weighted mean multiplied by the number of
            spatial pixels.
        weights : str | ArrayLike, optional
            Weights for the marginalization. Default is None.
                None: no weights
                ivar: inverse variance
                snr: signal-to-noise ratio squared
                array: explicit weights, either with the shape of the
                spectrum or one weight per spatial pixel
            ``ivar`` and ``snr`` fall back to uniform weights when every
            finite error is zero.
        sigma_clip : float, optional
            Sigma clipping threshold for the marginalization. Pixels further
            than ``sigma_clip`` standard deviations from the median of their
            column are excluded. Default is 5.
        nan_threshold : float, optional
            Threshold for the fraction of NaN values in a column.
            Columns with NaN fraction > nan_threshold will be masked.
            Default is 0.1.

        Returns
        -------
        SpecWrapper
            The marginalized 1D spectrum.

        Raises
        ------
        ValueError
            If the spectrum is empty or not 2D, or ``margin_type`` /
            ``weights`` is invalid.
        """
        if self.Y is None:
            raise ValueError("Marginalizing requires non-empty spectra.")
        if self.Y.ndim != 2:
            raise ValueError("Marginalizing requires a 2D spectrum.")
        # Validate up-front instead of silently returning None at the end.
        if margin_type not in ("mean", "sum"):
            raise ValueError(f"Invalid margin_type: {margin_type!r}.")

        if weights is None:
            w = jnp.ones_like(self.Y)
        elif isinstance(weights, str):
            if weights not in ("ivar", "snr"):
                raise ValueError("Invalid weights.")
            # Error-based weights are undefined without errors. NaN errors
            # (masked pixels) are ignored when deciding that.
            no_errors = jnp.all((self.Yerr == 0) | ~jnp.isfinite(self.Yerr))
            if no_errors:
                w = jnp.ones_like(self.Y)
            elif weights == "ivar":
                w = self.Yerr**-2
            else:
                w = (self.Y / self.Yerr) ** 2
        else:
            w = jnp.asarray(weights)
            if w.ndim == 1:
                # Broadcasting the weights to the same shape as the spectrum
                w = jnp.tile(w[:, None], reps=(1, self.Y.shape[1]))
            if w.shape != self.Y.shape:
                raise ValueError(
                    f"Input weights shape {w.shape} does not match the spectrum shape {self.Y.shape}"
                )

        # Per-column (per-wavelength) median and standard deviation
        Y_meds = jnp.nanmedian(self.Y, axis=0)
        Y_stds = jnp.nanstd(self.Y, axis=0, ddof=1)

        # Create the mask for sigma clipping
        # Broadcasting to compare each column with its own median and std
        deviations = jnp.abs(self.Y - Y_meds[None, :])
        sigma_masks = deviations <= (sigma_clip * Y_stds[None, :])
        valid_masks = jnp.isfinite(self.Y)
        combined_mask = sigma_masks & valid_masks

        # Calculate weighted means
        w_used = jnp.where(combined_mask, w, 0)
        w_total = jnp.sum(w_used, axis=0)
        weighted_values = jnp.where(combined_mask, self.Y * w_used, 0)
        mean_value = jnp.sum(weighted_values, axis=0) / w_total

        # Calculate errors
        weighted_errors = jnp.where(combined_mask, (self.Yerr * w_used) ** 2, 0)
        mean_value_err = jnp.sqrt(jnp.sum(weighted_errors, axis=0) / w_total**2)

        # Mask columns with NaN fraction > nan_threshold
        nan_fraction = jnp.sum(~valid_masks, axis=0) / self.shape[0]
        too_sparse = nan_fraction > nan_threshold
        mean_value = jnp.where(too_sparse, jnp.nan, mean_value)
        mean_value_err = jnp.where(too_sparse, jnp.nan, mean_value_err)

        scale = 1 if margin_type == "mean" else self.shape[0]
        return SpecWrapper(
            points=self.spec,
            values=mean_value * scale,
            values_err=mean_value_err * scale,
        )

    def subtract(self, other: "SpecWrapper") -> "SpecWrapper":
        """
        Subtract another spectrum.

        The errors are combined in quadrature. A 1D ``other`` is broadcast
        along the spatial axis of a 2D spectrum.

        Parameters
        ----------
        other : SpecWrapper
            The other spectrum to be subtracted.

        Returns
        -------
        SpecWrapper
            The subtracted spectrum.

        Raises
        ------
        ValueError
            If either spectrum holds no values, or the shapes do not match.
        """
        if self.Y is None or other.Y is None:
            raise ValueError("Subtraction requires non-empty spectra.")
        other_ndim = len(other.shape)
        mismatch_1d = other_ndim == 1 and other.shape[-1] != self.shape[-1]
        mismatch_2d = other_ndim == 2 and other.shape != self.shape
        if mismatch_1d or mismatch_2d:
            raise ValueError("Shape mismatch.")
        return SpecWrapper(
            points=self._points(),
            values=self.Y - other.Y,
            values_err=(self.Yerr**2 + other.Yerr**2) ** 0.5,
        )

    def apply_spatial_filter(self, spat_filter: ArrayLike) -> "SpecWrapper":
        """
        Creates a new spectrum by applying a spatial mask to the current spectrum.

        Parameters
        ----------
        spat_filter : ArrayLike
            Filter to apply to the spatial axis (used to index the first
            axis of the spatial coordinates, values and errors).

        Returns
        -------
        SpecWrapper
            A new spectrum containing only the data points selected by the
            mask. If this spectrum holds coordinates only, so does the
            result.

        Raises
        ------
        ValueError
            If the spectrum has no spatial axis.
        """
        if self.spat is None:
            raise ValueError("Spatial filtering requires a 2D spectrum.")
        # JAX arrays cannot be indexed by plain Python sequences.
        selector = jnp.asarray(spat_filter)
        points = (self.spat[selector], self.spec)
        if self.Y is None:
            return SpecWrapper(points=points)
        return SpecWrapper(
            points=points,
            values=self.Y[selector],
            values_err=self.Yerr[selector],
        )

    def convolve(
        self, kernel_wid: float | ArrayLike, max_kernel_size: int = 30
    ) -> "SpecWrapper":
        """
        Convolve the spectrum with a kernel.

        Every spectral column of the 2D spectrum is convolved along the
        spatial axis with a normalised Gaussian whose width may differ from
        column to column. The columns are zero-padded at both ends. The
        errors are passed through the same filter as the values.

        Parameters
        ----------
        kernel_wid : float | ArrayLike
            The width (FWHM in pixel) of the Gaussian kernel: a scalar, or
            one width per spectral pixel. A width of zero leaves the
            corresponding column unchanged.
        max_kernel_size : int, optional
            Size (in pixel) of the window the kernel is evaluated on; the
            kernel spans ``max_kernel_size // 2`` pixels on either side.
            Default is 30.

        Returns
        -------
        SpecWrapper
            The seeing-matched spectrum.

        Raises
        ------
        ValueError
            If the spectrum is empty or not 2D, or the kernel width is
            invalid or has the wrong length.
        """
        if self.Y is None:
            raise ValueError("Convolution requires non-empty spectra.")
        if self.spat is None:
            raise ValueError("Convolution requires a 2D spectrum.")

        # Convert once so that ints, floats, sequences and arrays all take
        # the same path.
        try:
            fwhm = jnp.asarray(kernel_wid)
        except (TypeError, ValueError) as err:
            raise ValueError("Invalid kernel width.") from err

        if jnp.all(fwhm == 0):
            msgs.warning("Kernel width is zero. Returning the original spectrum.")
            return SpecWrapper(
                points=self._points(), values=self.Y, values_err=self.Yerr
            )

        if fwhm.ndim == 0:
            kernel_sigma = jnp.ones_like(self.spec) * fwhm / 2.355
        elif fwhm.size == self.spec.size:
            kernel_sigma = fwhm.ravel() / 2.355
        else:
            raise ValueError(
                f"The length of the kernel_wid array ({fwhm.size}) does not match the spectrum ({self.spec.size})."
            )

        radius = max_kernel_size // 2
        offsets = jnp.arange(-radius, radius + 1)

        def smooth_column(column: Array, sigma: Array) -> Array:
            """Convolve one spectral column with a Gaussian of width sigma."""
            # sigma == 0 would give 0/0 in the exponent and turn the whole
            # column into NaN; use a delta kernel (identity) instead.
            is_zero = sigma == 0
            safe_sigma = jnp.where(is_zero, 1.0, sigma)
            gauss = jnp.exp(-0.5 * (offsets / safe_sigma) ** 2)
            delta = jnp.where(offsets == 0, 1.0, 0.0)
            kernel = jnp.where(is_zero, delta, gauss)
            kernel = kernel / jnp.sum(kernel)
            # Pad and convolve
            padded = jnp.pad(column, (radius, radius), mode="constant")
            return jnp.convolve(padded, kernel, mode="valid")

        # Map over the spectral axis of the image and over the widths.
        smooth = jax.jit(jax.vmap(smooth_column, in_axes=(1, 0), out_axes=1))

        Y_conv = smooth(self.Y, kernel_sigma)
        Yerr_conv = smooth(self.Yerr, kernel_sigma)
        return SpecWrapper(
            points=self._points(), values=Y_conv, values_err=Yerr_conv
        )
@partial(jax.jit, static_argnames=["axis"])
def mad_std(data: Array, axis: Optional[int] = None) -> Array:
    """
    Compute the scaled Median Absolute Deviation (MAD) of the input data.

    NaN values are ignored. The MAD is multiplied by 1.4826, the usual
    factor that turns it into a standard-deviation estimate for normally
    distributed data.

    Parameters
    ----------
    data : Array
        The input data.
    axis : int, optional
        Axis along which the statistic is computed. Default is None, which
        uses all elements.

    Returns
    -------
    Array
        The scaled MAD. The reduced axes are kept with length one.
    """
    center = jnp.nanmedian(data, axis=axis, keepdims=True)
    abs_deviation = jnp.abs(data - center)
    mad = jnp.nanmedian(abs_deviation, axis=axis, keepdims=True)
    return mad * 1.4826
@partial(jax.jit, static_argnames=["clip_cr"])
def _clip(Y: Array, Yerr: Array, sigma: float, clip_cr: bool) -> tuple[Array, Array]:
    """
    Sigma clipping for a batch of the spectrum.

    Pixels are compared with the median of the batch, using the scaled MAD
    from ``mad_std`` as the scatter. Rejected and non-finite pixels are set
    to NaN in both returned arrays.

    Parameters
    ----------
    Y : Array
        Values of the batch, 1D or 2D.
    Yerr : Array
        Errors of the batch, with the same number of dimensions as ``Y``.
    sigma : float
        Clipping threshold in units of the scatter.
    clip_cr : bool
        If True only positive outliers are rejected; otherwise outliers on
        both sides are rejected.

    Returns
    -------
    tuple[Array, Array]
        The clipped values and errors.

    Raises
    ------
    ValueError
        If ``Y`` and ``Yerr`` are not both 1D or both 2D.
    """
    dims = (Y.ndim, Yerr.ndim)
    if dims == (1, 1):
        # Statistics over the whole batch
        stat_axis = None
    elif dims == (2, 2):
        # Statistics along axis 1 of the batch.
        # NOTE(review): the original comment calls this the spatial axis,
        # but SpecWrapper stores images as (spatial, spectral) -- confirm
        # which axis is intended for 2D batches.
        stat_axis = 1
    else:
        raise ValueError(
            f"Invalid shape of the input data: Y.shape={Y.shape}, Yerr.shape={Yerr.shape}"
        )

    center = jnp.nanmedian(Y, axis=stat_axis, keepdims=True)
    scatter = mad_std(Y, axis=stat_axis)

    residual = Y - center
    if not clip_cr:
        # Two-sided clipping: negative outliers count as well
        residual = jnp.abs(residual)
    keep = (residual <= (sigma * scatter)) & jnp.isfinite(Y)

    return jnp.where(keep, Y, jnp.nan), jnp.where(keep, Yerr, jnp.nan)