Source code for hostsub_gp.host_model

# hostsub_gp/host_model.py

# Names exported by ``from hostsub_gp.host_model import *``.
__all__: list[str] = [
    "HostProfile",
]

import numpy as np

import jax
import jax.numpy as jnp

# jax.config.update("jax_enable_x64", True)


from .gp import GP
from .host_image import PS1Image, SDSSImage, LSImage
from .host_image import ImageProduct
from ._utils import plt, msgs
from ._utils._plt import show_and_save

from typing import Callable, Optional, Literal, Type, Any
from jax._src.typing import Array
from .spec_base import SpecModelP


class HostProfile:
    """
    Host-galaxy spatial profile along the slit, built from one or more bands.

    Instances hold the per-band profiles trimmed to the host region
    (``prof_slit``, ``prof_err_slit``, ``spat_slit``, ``wv_slit``), the retained
    ``filters`` and ``wv_eff``, and the concatenated training set used by
    :meth:`model_host_profile_prior`: ``prof``, ``prof_err`` and ``X``
    (columns: spatial coordinate, effective wavelength).
    """

    def __init__(
        self,
        filters: str | list[str],
        wv_eff: list[float],
        spat_slit: list[Array],
        counts_slit: list[Array],
        counts_err_slit: list[Array],
        spec_model: Optional[SpecModelP] = None,
        slit_len: Optional[float] = None,
        pixel_scale: Optional[float] = None,
    ):
        """
        Estimate the host galaxy spatial profile from per-band slit counts.

        With ``spec_model``, each band is sky-subtracted using the sky windows
        at the slit edges and normalized by the sky-subtracted mean count in
        the host window (the host window minus the central mask). Without it,
        each band is normalized by its total count.

        Bands whose profile or error contains non-finite values are dropped
        with a warning. If no band remains (or none was given), ``prof``,
        ``prof_err`` and ``X`` are empty, and
        :meth:`model_host_profile_prior` falls back to a flat prior.

        Parameters
        ----------
        filters : str or list[str]
            Filter names, one per band (a string is indexed character by
            character).
        wv_eff : list[float]
            Effective wavelength of each band.
        spat_slit : list[Array]
            Spatial coordinates along the slit, per band.
        counts_slit : list[Array]
            Counts along the slit, per band.
        counts_err_slit : list[Array]
            Errors of the counts along the slit, per band.
        spec_model : SpecModelP, optional
            SpecModel object. When given, it supplies the pixel scale and the
            slit/host/mask/sky window edges (``spat_edges``).
        slit_len : float, optional
            Slit length in arcsec (required without ``spec_model``).
        pixel_scale : float, optional
            Pixel scale in arcsec (required without ``spec_model``).

        Raises
        ------
        ValueError
            If ``spec_model`` is None and ``slit_len`` or ``pixel_scale`` is
            missing.
        """
        self.filters = filters
        self.wv_eff = wv_eff
        n_bands = len(filters)

        # The region boundaries do not depend on the band: resolve them once,
        # before the loop, so they are defined even when there are no bands.
        if spec_model is not None:
            edges = spec_model.spat_edges
            pixel_scale = spec_model.pixel_scale
            # Host window, split around the central mask.
            host_left = (edges["host"][0], edges["mask"][0])
            host_right = (edges["mask"][1], edges["host"][1])
            # Sky windows: from each slit edge to the sky edge, clipped to the slit.
            sky_left = (edges["slit"][0], max(edges["sky"][0], edges["slit"][0]))
            sky_right = (min(edges["sky"][1], edges["slit"][1]), edges["slit"][1])
            host_width = (host_left[1] - host_left[0]) + (
                host_right[1] - host_right[0]
            )
            sky_width = (sky_left[1] - sky_left[0]) + (sky_right[1] - sky_right[0])
            self.host_wid = edges["host"][1] - edges["host"][0]
        elif slit_len is None or pixel_scale is None:
            raise ValueError("Slit length and pixel scale are required")
        else:
            # No mask: the host region is the whole slit.
            host_left = (-slit_len / 2, 0)
            host_right = (0, slit_len / 2)
            # Host width not specified: use the slit length.
            self.host_wid = slit_len

        prof_slit, prof_err_slit = [], []
        for k in range(n_bands):
            spat, xi, xi_err = spat_slit[k], counts_slit[k], counts_err_slit[k]
            if spec_model is not None:
                xi_sky_mean = (
                    bound_sum(spat, xi, x_bound=sky_left)
                    + bound_sum(spat, xi, x_bound=sky_right)
                ) / sky_width
                xi_host_mean = (
                    bound_sum(spat, xi, x_bound=host_left)
                    + bound_sum(spat, xi, x_bound=host_right)
                ) / host_width
                # Sky-subtract, scale to unit mean over the host window, divide
                # by the host-window width, then rescale by the pixel scale.
                prof_slit.append(
                    (xi - xi_sky_mean)
                    / (xi_host_mean - xi_sky_mean)
                    / host_width
                    * pixel_scale
                )
                prof_err_slit.append(
                    xi_err / (xi_host_mean - xi_sky_mean) / host_width * pixel_scale
                )
            else:
                total = np.sum(xi)
                prof_slit.append(xi / total / pixel_scale)
                prof_err_slit.append(xi_err / total / pixel_scale)

        # Trim every band to the host region [host_left[0], host_right[1]].
        host_idx = [
            np.argwhere(
                (spat_slit[k] >= host_left[0]) & (spat_slit[k] <= host_right[1])
            ).ravel()
            for k in range(n_bands)
        ]

        # Drop bands with non-finite values (e.g. a zero sky-subtracted host
        # level) and trim the arrays accordingly.
        valid_idx = [
            k
            for k in range(n_bands)
            if np.all(np.isfinite(prof_slit[k]))
            and np.all(np.isfinite(prof_err_slit[k]))
        ]
        if len(valid_idx) < n_bands:
            msgs.warning(
                f"Dropping {n_bands - len(valid_idx)} filters with NaN values in the host profile estimation."
            )

        self.prof_slit = [prof_slit[k][host_idx[k]] for k in valid_idx]
        self.prof_err_slit = [prof_err_slit[k][host_idx[k]] for k in valid_idx]
        self.spat_slit = [spat_slit[k][host_idx[k]] for k in valid_idx]
        self.wv_slit = [
            np.full_like(host_idx[k], self.wv_eff[k], dtype=float) for k in valid_idx
        ]
        self.filters = [self.filters[k] for k in valid_idx]
        self.wv_eff = [self.wv_eff[k] for k in valid_idx]

        if valid_idx:
            self.prof = jnp.concatenate(self.prof_slit)
            self.prof_err = jnp.concatenate(self.prof_err_slit)
            self.X = jnp.stack(
                [jnp.concatenate(self.spat_slit), jnp.concatenate(self.wv_slit)],
                axis=-1,
            )
        else:
            # jnp.concatenate rejects an empty list. Keep empty training data
            # so that model_host_profile_prior can use its flat prior.
            self.prof = jnp.zeros(0)
            self.prof_err = jnp.zeros(0)
            self.X = jnp.zeros((0, 2))

    @staticmethod
    def _suppress_fitsfixed_warning(func):
        """
        Decorator to suppress FITSFixedWarning in the decorated function.
        """
        import warnings
        from functools import wraps

        from astropy.wcs import FITSFixedWarning

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Suppress FITSFixedWarning
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=FITSFixedWarning)
                return func(*args, **kwargs)

        return wrapper

    @staticmethod
    @_suppress_fitsfixed_warning
    def load_archival_images(
        spec_model: Optional[SpecModelP] = None,
        center_ra: Optional[float] = None,
        center_dec: Optional[float] = None,
        slit_len: Optional[float] = None,
        slit_wid: float = 1.0,
        position_angle: Optional[float] = None,
        filters: Optional[str | list] = None,
        survey: Literal["PS1", "LS", "any"] = "PS1",
    ) -> list[ImageProduct]:
        """
        Load, rotate, and resample archival images from PS1, LS, and SDSS.

        The SDSS u-band image is requested only when ``"u"`` is among
        ``filters``. The other bands come from LS (``survey="LS"``), from PS1
        (``survey="PS1"``), or from LS with a fallback to PS1 when LS returns
        no image (``survey="any"``).

        Parameters
        ----------
        spec_model : SpecModelP, optional
            When given, supplies the slit centre, length, width and position
            angle, overriding the corresponding arguments.
        center_ra, center_dec : float, optional
            Slit centre coordinates (required without ``spec_model``).
        slit_len : float, optional
            Slit length in arcsec.
        slit_wid : float, default 1.0
            Slit width in arcsec.
        position_angle : float, optional
            Slit position angle (required without ``spec_model``).
        filters : str or list, optional
            Requested bands; defaults to ``"grizy"``.
        survey : {"PS1", "LS", "any"}, default "PS1"
            Survey used for the g/r/i/z(/y) bands.

        Returns
        -------
        list[ImageProduct]
            One product per loaded band, sorted by effective wavelength.

        Raises
        ------
        ValueError
            If the coordinates or position angle are missing, or no image is
            found for the selected survey.
        """
        from operator import attrgetter

        if spec_model is not None:
            center_ra = spec_model.center_ra
            center_dec = spec_model.center_dec
            slit_len = (
                spec_model.spat_edges["slit"][1] - spec_model.spat_edges["slit"][0]
            )
            slit_wid = spec_model.slit_wid
            position_angle = spec_model.position_angle
        else:
            if center_ra is None or center_dec is None:
                raise ValueError("Coordinates are required")
            if position_angle is None:
                raise ValueError("Position angle is required")

        if filters is None:
            # Load all filters
            filters = "grizy"

        img_product_list = []

        def _load_images(
            image_class: Type[SDSSImage | LSImage | PS1Image],
            filters: str,
            center_ra: float,
            center_dec: float,
        ) -> bool:
            """
            Load images from ``image_class`` into ``img_product_list``.

            Returns True if at least one image product was added.
            """
            # Nothing requested from this survey: skip the query altogether.
            if not filters:
                return False

            image = image_class(ra=center_ra, dec=center_dec, filters=filters)
            image.get_cutout()
            imgs, headers = image.load()

            img_products = []
            for img, header, flt in zip(imgs, headers, image.filters):
                img_products.append(
                    ImageProduct(
                        center_ra=center_ra,
                        center_dec=center_dec,
                        slit_len=slit_len,
                        slit_wid=slit_wid,
                        position_angle=position_angle,
                        img=img,
                        header=header,
                        flt=flt,
                        wv_eff=image.wv_eff_dict[flt],
                    )
                )

            # No images found
            if not img_products:
                return False

            img_product_list.extend(img_products)
            return True

        # Try to load SDSS u-band image
        _load_images(
            SDSSImage,
            "".join(flt for flt in "u" if flt in filters),
            center_ra,
            center_dec,
        )

        # Try LS images first if requested. With survey="any", a failed LS
        # load must fall through to PS1 instead of raising.
        ls_loaded = False
        if survey in ("LS", "any"):
            ls_loaded = _load_images(
                LSImage,
                "".join(flt for flt in "griz" if flt in filters),
                center_ra,
                center_dec,
            )
            if not ls_loaded and survey == "LS":
                raise ValueError("No LS images found")

        if survey == "PS1" or (survey == "any" and not ls_loaded):
            ps1_loaded = _load_images(
                PS1Image,
                "".join(flt for flt in "grizy" if flt in filters),
                center_ra,
                center_dec,
            )
            if not ps1_loaded:
                raise ValueError("No PS1 images found")

        # TODO: Load acquisition images (optional)

        return sorted(img_product_list, key=attrgetter("wv_eff"))

    @classmethod
    @_suppress_fitsfixed_warning
    def from_archival(
        cls,
        img_products: list[ImageProduct],
        spec_model: Optional[SpecModelP] = None,
        slit_len: Optional[float] = None,
        slit_wid: Optional[float] = None,
        pixel_scale: Optional[float] = None,
        dseeing: Optional[float] = None,
        alpha: Optional[float] = 0.2,
        verbose: bool = False,
    ):
        """
        Estimate the host galaxy spatial profile from archival image products.

        For each image product, the counts are averaged across the slit width
        to give a profile along the slit. The error is the row-wise standard
        deviation across the slit width, boxcar-smoothed and capped at the
        image's ``err``. The result is passed to the ``HostProfile``
        constructor.

        Parameters
        ----------
        img_products : list[ImageProduct]
            Image products, e.g. from :meth:`load_archival_images`.
        spec_model : SpecModelP, optional
            SpecModel object. When given, it supplies the slit width, slit
            length and pixel scale; it is required for the seeing correction.
        slit_len : float, optional
            Slit length in arcsec (required without ``spec_model``).
        slit_wid : float, optional
            Slit width in arcsec (required without ``spec_model``).
        pixel_scale : float, optional
            Pixel scale in arcsec of the 2D spectrum (not the archival images).
        dseeing : float, optional
            Seeing difference in arcsec. Only its absolute value is used. When
            non-zero, each image is convolved with a Gaussian of FWHM
            ``|dseeing| * (wv_eff / mean(spec_model.spec)) ** (-alpha)``.
        alpha : float, optional, default = 0.2
            Wavelength-scaling exponent of the seeing correction.
        verbose : bool, default False
            Log the convolution kernel of each band.

        Returns
        -------
        HostProfile
        """
        from scipy.ndimage import gaussian_filter

        if spec_model is not None:
            slit_wid = spec_model.slit_wid
            slit_len = (
                spec_model.spat_edges["slit"][1] - spec_model.spat_edges["slit"][0]
            )
            pixel_scale = spec_model.pixel_scale
        else:
            assert slit_len is not None and slit_wid is not None
            assert pixel_scale is not None

        # Seeing correction. By default dseeing is negative, so only its
        # magnitude is used; float() also accepts integer input.
        dseeing = 0.0 if dseeing is None else float(np.abs(dseeing))

        # Spatial coordinates along the slit
        spat_slit = [img_product.spat_slit for img_product in img_products]

        # Boxcar (in pixels along the slit) used to smooth the noise estimate.
        noise_smooth_kernel = 3
        boxcar = np.ones(noise_smooth_kernel) / noise_smooth_kernel

        counts_slit, counts_err_slit = [], []
        for img_product in img_products:
            if dseeing > 0:
                assert spec_model is not None, (
                    "SpecModel is required for dseeing correction"
                )
                dseeing_wv = dseeing * (
                    img_product.wv_eff / spec_model.spec.mean()
                ) ** (-alpha)
                # NOTE(review): sigma is expressed in 2D-spectrum pixels
                # (spec_model.pixel_scale) but gaussian_filter applies it in
                # archival-image pixels. Confirm whether the archival image
                # pixel scale should be used here instead.
                sigma_pix = dseeing_wv / spec_model.pixel_scale / 2.355
                img = gaussian_filter(img_product.img, sigma=sigma_pix)
                if verbose:
                    msgs.info(
                        f"Convolving {img_product.flt} with a {dseeing_wv:.2f} arcsec kernel "
                        + f"(sigma = {sigma_pix:.2f} pixels)"
                    )
            else:
                img = img_product.img

            counts = bound_mean_img(
                img_product.spat_slit_wid,
                img,
                x_bound=(-slit_wid / 2, slit_wid / 2),
            )
            counts_slit.append(counts)

            # Error estimate: standard deviation, across the slit width, of the
            # unconvolved image, smoothed along the slit with the boxcar
            # (applied in variance space).
            img_slit = img_product.img[
                :, np.abs(img_product.spat_slit_wid) < slit_wid / 2
            ]
            err1 = np.nanstd(img_slit, axis=1, ddof=1)
            err1 = np.convolve(err1**2, boxcar, mode="same") ** 0.5

            # Cap at the image-level error. np.where (not np.minimum) keeps
            # err2 wherever err1 is NaN.
            err2 = np.ones_like(counts) * img_product.err
            counts_err_slit.append(np.where(err1 < err2, err1, err2))

        flts = [img_product.flt for img_product in img_products]
        wv_effs = [img_product.wv_eff for img_product in img_products]

        return cls(
            filters=flts,
            wv_eff=wv_effs,
            spat_slit=spat_slit,
            counts_slit=counts_slit,
            counts_err_slit=counts_err_slit,
            spec_model=spec_model,
            slit_len=slit_len,
            pixel_scale=pixel_scale,
        )

    @msgs.timer
    def model_host_profile_prior(
        self, spat_resln: float = 1.0, params_init: Optional[dict] = None, **kwargs
    ) -> Callable[[Any], Array | tuple[Array, Array]]:
        """
        Model the host galaxy spatial profile using Gaussian Process regression.

        The model depends on how many bands are available:

        * none: a flat prior ``1 / host_wid`` with zero variance (no plot);
        * one: a GP over the spatial coordinate only;
        * several: a GP over (spatial coordinate, effective wavelength).

        Parameters
        ----------
        spat_resln : float, optional
            Spatial resolution (seeing) in arcsec. It sets the default initial
            length scale and its bounds.
        params_init : dict, optional
            Initial GP hyper-parameters. Must contain ``log_scale``,
            ``log_amp`` and ``mean``. Defaults depend on the number of bands.
        **kwargs
            Forwarded to ``_plot_host_profile``.

        Returns
        -------
        host_prior : Callable[[Array], tuple[Array, Array]]
            A function that returns the mean and variance of the host profile.
            The fitted hyper-parameters are stored in ``self._gp_params``.
        """
        if params_init is not None:
            assert "log_scale" in params_init, "log_scale is required in params_init"
            assert "log_amp" in params_init, "log_amp is required in params_init"
            assert "mean" in params_init, "mean is required in params_init"

        n_bands = len(self.filters)

        # No prior photometric data
        if n_bands == 0:

            def host_prior_flat(
                x: Array,
            ) -> tuple[Array, Array]:
                # constant, variance = 0
                return jnp.array(1 / self.host_wid, dtype=jnp.float32), jnp.array(
                    0, dtype=jnp.float32
                )

            self._gp_params = {}
            # Nothing to plot: plt.subplots(0, ...) would raise.
            return host_prior_flat

        if n_bands == 1:
            # Single band - no wavelength dependence
            if params_init is None:
                params_init = dict(
                    log_amp=np.float64(-3),
                    log_scale=np.log10(spat_resln),
                    mean=np.float64(1 / self.host_wid),
                )
            params_limit = dict(log_scale=np.log10([spat_resln / 2.355, np.inf]))
            # NOTE(review): "HostProfie" differs from the "HostProfile" kernel
            # used for the multi-band fit. It is kept unchanged because the
            # scalar parameters here may depend on how GP resolves this name.
            # Confirm against GP.
            gp_host_prior = GP(
                kernel_type="HostProfie",
                X=self.X[:, :1],  # Spatial coordinate only
                y=self.prof,
                yerr=self.prof_err,
                params_init=params_init,
                params_limit=params_limit,
                optimization=True,
            )
            self._gp_params = gp_host_prior.params

            def host_prior_single(x: Array) -> tuple[Array, Array]:
                return gp_host_prior.predict(X_test=x[:, :1], return_var=True)

            host_prior = host_prior_single
        else:
            # Multiple bands - wavelength dependence
            if params_init is None:
                params_init = dict(
                    log_amp=np.ones((2, 2)) * -2,
                    log_scale=np.log10([[spat_resln, spat_resln], [1e5, 1e4]]),
                    mean=np.float64(1 / self.host_wid),
                )
            params_limit = dict(
                log_scale=np.log10(
                    [
                        [
                            [spat_resln / 2.355, spat_resln / 2.355],
                            [1e3, 1e3],
                        ],  # lower bound
                        [
                            [np.inf, spat_resln * 2],
                            [np.inf, np.inf],
                        ],  # upper bound
                    ]
                )
            )
            gp_host_prior = GP(
                kernel_type="HostProfile",
                X=self.X,
                y=self.prof,
                yerr=self.prof_err,
                params_init=params_init,
                params_limit=params_limit,
                optimization=True,
            )
            self._gp_params = gp_host_prior.params

            def host_prior_multi(x: Array) -> tuple[Array, Array]:
                return gp_host_prior.predict(X_test=x, return_var=True)

            host_prior = host_prior_multi

        self._plot_host_profile(host_prior, **kwargs)
        return host_prior

    @show_and_save
    def _plot_host_profile(self, host_prior) -> None:
        """
        Plot the measured host profile (with 1-sigma band) and the prior mean
        for each band, one panel per band.
        """
        from matplotlib.colors import Normalize

        n_bands = len(self.filters)
        _, ax = plt.subplots(
            n_bands,
            1,
            figsize=(10, 3 * n_bands),
            sharex=True,
            sharey=True,
            constrained_layout=True,
        )
        ax = np.atleast_1d(ax)
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the pyplot
        # accessor is available across versions.
        cmap = plt.get_cmap("coolwarm")
        norm = Normalize(vmin=0, vmax=n_bands - 1)

        for k in range(n_bands):
            color = cmap(norm(k))
            ax[k].plot(
                self.spat_slit[k],
                self.prof_slit[k],
                label=f"{self.filters[k]}",
                color=color,
            )
            ax[k].plot(
                self.spat_slit[k],
                host_prior(jnp.stack([self.spat_slit[k], self.wv_slit[k]], axis=-1))[0],
                "--",
                color=color,
            )
            ax[k].fill_between(
                self.spat_slit[k],
                self.prof_slit[k] - self.prof_err_slit[k],
                self.prof_slit[k] + self.prof_err_slit[k],
                color=color,
                alpha=0.2,
            )
            ax[k].set_ylabel(r"$\mathrm{Profile}$")
            ax[k].text(
                0.05,
                0.8,
                f"{self.filters[k]}: {self.wv_eff[k]:.0f} Ang",
                color=color,
                transform=ax[k].transAxes,
            )
        ax[-1].set_xlabel(r"$\mathrm{Spat\ [arcsec]}$")
def bound_sum(x: Array, y: Array, x_bound: Optional[tuple] = None) -> Array:
    """
    Integrate a binned signal over a bounded region.

    ``x`` holds pixel centres. Each pixel spans half a bin on either side of
    its centre, with bin widths from ``jnp.diff(x)``; the first bin reuses
    the first spacing. The result is ``sum(y * overlap)``, where ``overlap``
    is the length of each pixel lying inside ``x_bound``. The two partially
    covered edge pixels use the first and last bin widths respectively. If
    the region extends past the data, the edge pixel's value is extrapolated
    over the uncovered part.

    Parameters
    ----------
    x : Array
        Pixel centres, assumed sorted in increasing order.
    y : Array
        Value of each pixel, same length as ``x``.
    x_bound : tuple, optional
        ``(lower, upper)`` limits of the region. Defaults to the full extent
        covered by ``x``.

    Returns
    -------
    Array
        Scalar integral of ``y`` over the region; ``0`` when
        ``upper <= lower``.

    Raises
    ------
    ValueError
        If the region lies entirely outside the range covered by ``x``.
    """
    bin_size = jnp.append(x[1] - x[0], jnp.diff(x))
    half_left = bin_size[0] / 2
    half_right = bin_size[-1] / 2
    if x_bound is None:
        x_bound = (x[0] - half_left, x[-1] + half_right)
    lower, upper = x_bound[0], x_bound[1]
    if upper <= lower:
        return jnp.float32(0)

    # Pixels fully contained in the region (strictly inside its edge pixels).
    idx_center = (x > lower + half_left) & (x < upper - half_right)
    sum_center = jnp.sum(y[idx_center] * bin_size[idx_center])

    # Pixel containing the lower bound. The comparison is strict so that, when
    # `lower` falls exactly on a pixel edge, the fully covered pixel to its
    # right is chosen (weight one bin). A non-strict comparison would pick the
    # uncovered pixel to its left (weight 0) and drop the covered one.
    idx_left = jnp.where(x > lower - half_left)[0]
    if idx_left.size == 0:
        raise ValueError("Invalid left bound")
    i_left = int(idx_left[0])

    # Pixel containing the upper bound (strict for the same reason).
    idx_right = jnp.where(x < upper + half_right)[0]
    if idx_right.size == 0:
        raise ValueError("Invalid right bound")
    i_right = int(idx_right[-1])

    if i_left == i_right:
        # The whole region lies within one pixel. Adding both edge fractions
        # would count that pixel twice; its weight is the region width.
        return y[i_left] * (upper - lower)

    frac_left = x[i_left] - (lower - half_left)
    frac_right = (upper + half_right) - x[i_right]
    return sum_center + y[i_left] * frac_left + y[i_right] * frac_right


def bound_mean(x: Array, y: Array, x_bound: tuple) -> Array:
    """
    Average a binned signal over a bounded region.

    Equals ``bound_sum(x, y, x_bound)`` divided by the region width
    ``x_bound[1] - x_bound[0]``.
    """
    return bound_sum(x, y, x_bound) / (x_bound[1] - x_bound[0])


def bound_mean_img(x: Array, y_img: Array, x_bound: tuple) -> Array:
    """
    Apply ``bound_mean`` to each row of ``y_img`` (rows vectorized with
    ``jax.vmap``); ``x`` gives the coordinates along the second axis.
    """
    return jax.vmap(lambda y: bound_mean(x, y, x_bound))(y_img)