Source code for hostsub_gp.spec_proc

# hostsub_gp/spec_proc.py

__all__ = ["SpecData"]

import numpy as np
import jax
import jax.numpy as jnp
import os

# jax.config.update("jax_enable_x64", True)

from jax._src.typing import ArrayLike, Array
from typing import Callable, Optional, Literal

from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy import units as u

from .interp import Interp1D_Grid, Interp2D_Grid
from .interp import Interp2D_Scipy, Interp2D_RBF, Interp2D_Linear
from .spec_model import SpecModel
from .host_model import HostProfile
from ._utils import plt, msgs
from ._utils._plt import show_and_save


class SpecData:
    """
    A class to load (from different spectrographs) and preprocess
    (rectification & alignment) the 2D spectrum.
    """

    def __init__(
        self,
        pixel_scale: float,
        center_ra: float,
        center_dec: float,
        slit_wid: float,
        position_angle: float,
        spat_resln: float,
        spec_resln: float,
        spat_rect: ArrayLike,
        spec_rect: ArrayLike,
        flux_rect: Optional[Array] = None,
        flux_ivar_rect: Optional[Array] = None,
        dist: Optional[Array] = None,
        waveimg: Optional[Array] = None,
        flux: Optional[Array] = None,
        flux_global_sky: Optional[Array] = None,
        flux_ivar: Optional[Array] = None,
        sky_offset: Optional[float] = None,
        survey: Literal["PS1", "LS"] = "PS1",
        to_caches: Optional[bool] = False,
        cache_path: Optional[str] = None,
    ):
        """
        Build a (rectified) 2D spectrum.

        Either pass already-rectified ``flux_rect`` / ``flux_ivar_rect``, or
        pass the unrectified images (``dist``, ``waveimg``, ``flux``,
        ``flux_ivar``) and let the constructor rectify them onto the
        ``spat_rect`` x ``spec_rect`` grid.

        Parameters
        ----------
        pixel_scale : float
            Pixel scale (arcsec/pixel).
        center_ra, center_dec : float
            RA and DEC of the science object (degrees).
        slit_wid : float
            Slit width (arcsec).
        position_angle : float
            Slit position angle (degrees).
        spat_resln : float
            Spatial resolution (seeing FWHM, arcsec).
        spec_resln : float
            Spectral resolution (FWHM of the arc lines, Angstrom).
        spat_rect, spec_rect : ArrayLike
            Strictly increasing spatial (arcsec) and spectral (Angstrom)
            coordinates of the rectified grid.
        flux_rect, flux_ivar_rect : Array, optional
            Rectified flux and inverse variance, expected with shape
            ``(len(spat_rect), len(spec_rect))``. If both are given, no
            rectification is performed.
        dist, waveimg, flux, flux_ivar : Array, optional
            Unrectified distance-from-trace (arcsec), wavelength, flux and
            inverse-variance images. Required when the rectified arrays are
            not given.
        flux_global_sky : Array, optional
            Preliminary global sky subtracted from ``flux`` before
            rectification.
        sky_offset : float, optional
            Spatial offset (arcsec) subtracted from ``dist``. If None (and
            rectification is needed), it is estimated by cross-correlating
            the data with the archival host-galaxy profile.
        survey : {"PS1", "LS"}, optional (default: "PS1")
            Imaging survey used for the archival host-galaxy images.
        to_caches : bool, optional (default: False)
            Whether to write the result to ``cache_path`` via `to_fits`.
        cache_path : str, optional (default: None)
            Output FITS path of the cache.
        """
        # Convert before checking: plain Python lists would otherwise be
        # compared lexicographically instead of element-wise.
        spat_arr = np.asarray(spat_rect)
        spec_arr = np.asarray(spec_rect)
        assert np.all(spat_arr[1:] > spat_arr[:-1]), (
            "Spatial coordinates must be in ascending order."
        )
        assert np.all(spec_arr[1:] > spec_arr[:-1]), (
            "Spectral coordinates must be in ascending order."
        )
        self.spat_rect = jnp.asarray(spat_rect)
        self.spec_rect = jnp.asarray(spec_rect)

        self.pixel_scale = pixel_scale
        self.center_ra = center_ra
        self.center_dec = center_dec
        self.slit_wid = slit_wid
        self.position_angle = position_angle
        self.spat_resln = spat_resln
        self.spec_resln = spec_resln
        self.sky_offset = sky_offset

        if (flux_rect is not None) and (flux_ivar_rect is not None):
            self.flux_rect = jnp.asarray(flux_rect)
            self.flux_ivar_rect = jnp.asarray(flux_ivar_rect)
        else:
            assert dist is not None, "Distance array is not available."
            assert flux is not None and flux_ivar is not None, (
                "Flux and ivar arrays are not available."
            )
            if self.sky_offset is None:
                # Slit length covered by the data, capped at 30 arcsec
                # (floored to a whole number of pixels), plus one pixel.
                slit_len = (
                    min(
                        self.spat_rect.max().item() - self.spat_rect.min().item(),
                        30 // self.pixel_scale * self.pixel_scale,
                    )
                    + self.pixel_scale
                )
                img_products = HostProfile.load_archival_images(
                    center_ra=self.center_ra,
                    center_dec=self.center_dec,
                    slit_len=slit_len,
                    slit_wid=self.slit_wid,
                    position_angle=self.position_angle,
                    survey=survey,
                )
                host_prior = HostProfile.from_archival(
                    img_products=img_products,
                    slit_len=slit_len,
                    slit_wid=self.slit_wid,
                    pixel_scale=self.pixel_scale,
                ).model_host_profile_prior()
                self.sky_offset = self._get_offset(
                    points=jnp.stack([dist, waveimg], axis=-1),
                    flux=flux,
                    show=True,
                    mask_wid=2.0,
                    host_prior=host_prior,
                )
            self._points = jnp.stack([dist - self.sky_offset, waveimg], axis=-1)

            # Rectify the 2D spectrum (after removing the preliminary global
            # sky, when one is provided)
            flux_in = flux if flux_global_sky is None else flux - flux_global_sky
            self.flux_rect, self.flux_ivar_rect = self._rectify(
                points=self._points,
                f_values=(flux_in, flux_ivar),
                interp_method="scipy",
            )

        # Save the 2D spectra to cache files
        self.cache_path = cache_path
        if to_caches:
            self.to_fits()

    @classmethod
    def from_pypeit(
        cls,
        sci_file: str,
        raw_dir: Optional[str] = None,
        obj_id: Optional[str] = None,
        std_file: Optional[str] = None,
        ra: Optional[float] = None,
        dec: Optional[float] = None,
        spat_resln: Optional[float] = None,
        spat_rect: Optional[ArrayLike] = None,
        spec_rect: Optional[ArrayLike] = None,
        **kwargs,
    ):
        """
        Load 2D spectra from PypeIt output files.

        Side effects: writes ``<spec2d>_preproc.fits`` (distance, wavelength
        and global-sky images), copies the spec1d file (if it exists) to
        ``<spec1d>_hostsub.fits``, and caches the rectified spectrum to
        ``<spec2d>_rect.fits`` through the constructor.

        Parameters
        ----------
        sci_file : str
            The spec1d or spec2d filename of the science object.
        raw_dir : str, optional (default: None)
            The directory of the raw files (not needed for LRIS data).
        obj_id : str, optional (default: None)
            The object ID in the science frame. If given, its trace is used.
        std_file : str, optional (default: None)
            The spec1d filename of the standard star. Its brightest object is
            used as the trace when ``obj_id`` is None, and to estimate the
            spatial resolution when ``spat_resln`` is None.
        ra, dec : float, optional (default: None)
            The RA and DEC of the science object. If not provided, the RA and
            DEC in the header will be used.
        spat_resln : float, optional (default: None)
            The spatial resolution (seeing) of the science frame.
        spat_rect : ArrayLike, optional (default: None)
            The spatial coordinates of the rectified 2D spectrum.
        spec_rect : ArrayLike, optional (default: None)
            The spectral coordinates of the rectified 2D spectrum.
        **kwargs
            Passed on to the constructor.

        Raises
        ------
        ValueError
            If the file name, trace source, raw directory or spatial
            resolution cannot be resolved.
        NotImplementedError
            For unsupported spectrographs.
        """
        import shutil

        from pypeit import spec2dobj, specobjs

        if "spec1d" in sci_file:
            spec1d_file = sci_file
            spec2d_file = sci_file.replace("spec1d", "spec2d")
        elif "spec2d" in sci_file:
            spec2d_file = sci_file
            spec1d_file = sci_file.replace("spec2d", "spec1d")
        else:
            raise ValueError("Incorrect sci_file format.")

        msgs.info(f"Loading 2D spectrum for {spec2d_file}...")
        pypeit_header = fits.getheader(spec2d_file)
        spectrograph = pypeit_header["PYP_SPEC"]

        # Access the header of the raw image
        raw_file = pypeit_header["FILENAME"]

        # Items to load
        # - RA, DEC (in degrees)
        # - detector ID
        # - pixel scale (in arcsec/pixel)
        # - position angle (in degrees)
        # - slit width (in arcsec)
        def _get_raw_header(
            raw_dir: str | None, raw_file: str, **kwargs
        ) -> fits.Header:
            """Get the header of the raw file (requires ``raw_dir``)."""
            if raw_dir is None:
                raise ValueError(
                    f"The raw file directory is needed for {spectrograph} data."
                )
            return fits.getheader(os.path.join(raw_dir, raw_file), **kwargs)

        def _highest_s2n(objs):
            """Return the SpecObj with the highest signal-to-noise ratio."""
            objs2n = []
            for obj in objs:
                assert obj is not None
                objs2n.append(obj["S2N"])
            return objs[np.argmax(objs2n)]

        bad_rows = []
        if spectrograph in [
            "keck_lris_blue",
            "keck_lris_red",
            "keck_lris_red_mark4",
        ]:
            # Keck/LRIS
            if ra is None or dec is None:
                ra, dec = pypeit_header["RA"], pypeit_header["DEC"]
            det = "DET01" if "red_mark4" in spectrograph else "DET02"
            # Binning in the spatial direction
            binning = int(pypeit_header["BINNING"].split(",")[1])
            pixel_scale = 0.135 * binning
            position_angle = pypeit_header["ROTPOSN"] + 90
            slit_wid = float(pypeit_header["SLITNAME"].split("_")[-1])
        elif spectrograph == "mmt_binospec":
            # MMT/Binospec
            raw_header = _get_raw_header(raw_dir, raw_file, ext=1)
            if ra is None or dec is None:
                # RA and DEC in the header are in the format of 'HH:MM:SS.SS' and 'DD:MM:SS.SS'
                ra_str = str(raw_header["CATRA"]).strip("'")
                dec_str = str(raw_header["CATDEC"]).strip("'")
                coord = SkyCoord(ra_str, dec_str, unit=(u.hourangle, u.deg))
                ra, dec = coord.ra.deg, coord.dec.deg
            det = "DET02"
            pixel_scale = 0.24
            # PA: parallactic angle
            # ROT: instrument rotator angle (relative to the parallactic angle)
            position_angle = raw_header["PA"] - raw_header["ROT"]
            slit_wid = float(raw_header["MASK"].split("Longslit")[-1])
            # MMT faulty rows to be removed
            bad_rows = [113, 210, 719, 1999, 2099, 3336, 3337, 4056, 4057] + [
                1717,
                3011,
            ]
        elif spectrograph == "not_alfosc":
            # NOT/ALFOSC
            raw_header = _get_raw_header(raw_dir, raw_file, ext=0)
            if ra is None or dec is None:
                ra, dec = pypeit_header["RA"], pypeit_header["DEC"]
            det = "DET01"
            # Binning in the spatial direction
            binning = int(pypeit_header["BINNING"].split(",")[1])
            pixel_scale = 0.2138 * binning
            position_angle = raw_header["FIELD"] + 180
            # Slit width format: "Slit_1.3"
            slit_wid = float(pypeit_header["DECKER"].split("_")[-1])
        elif spectrograph == "vlt_fors2":
            # VLT/FORS2
            raw_header = _get_raw_header(raw_dir, raw_file, ext=0)
            if ra is None or dec is None:
                ra, dec = pypeit_header["RA"], pypeit_header["DEC"]
            det = "DET01"
            pixel_scale = 0.25
            position_angle = -raw_header["HIERARCH ESO ADA POSANG"] + 180
            # Slit width format: "Slit1_0arcsec"
            slit_wid = float(
                ".".join(
                    pypeit_header["DECKER"]
                    .split("Slit")[-1]
                    .split("arcsec")[0]
                    .split("_")
                )
            )
        else:
            raise NotImplementedError(
                "Only LRIS, Binospec, ALFOSC, and FORS2 are supported"
            )

        # If the object ID in the science frame is provided (i.e., object
        # successfully found), use the object trace
        if obj_id is not None:
            trace_objs = specobjs.SpecObjs.from_fitsfile(spec1d_file, det=det)
            name_idx = trace_objs.name_indices(obj_id)
            if not any(name_idx):
                raise ValueError(f"Object {obj_id} not found in the trace file.")
            trace_obj = trace_objs[np.where(name_idx)[0][0]]
        # If the object ID is not provided, use the standard star trace
        elif std_file is not None:
            trace_obj = _highest_s2n(
                specobjs.SpecObjs.from_fitsfile(std_file, det=det)
            )
        else:
            raise ValueError("No spec1d file provided for identifying the trace.")

        if spat_resln is None:
            if std_file is None:
                raise ValueError(
                    "The spatial resolution needs to be either provided in the config file or estimated from the standard."
                )
            # Use a separate variable: the science trace selected above must
            # not be replaced by the standard-star trace.
            std_obj = _highest_s2n(specobjs.SpecObjs.from_fitsfile(std_file, det=det))
            spat_resln = std_obj["FWHM"] * pixel_scale

        assert trace_obj is not None, "Trace object is not available."
        # Spatial pixel of the trace (one value per spectral pixel)
        trace_spat_pix = jnp.asarray(trace_obj["TRACE_SPAT"], dtype=float)

        sci2d = spec2dobj.Spec2DObj.from_file(spec2d_file, detname=det)

        # Refill bad rows with the mean of the neighbouring rows
        for row in bad_rows:
            sci2d.sciimg[row, :] = (
                sci2d.sciimg[row - 1, :] + sci2d.sciimg[row + 1, :]
            ) / 2
            sci2d.ivarraw[row, :] = (
                sci2d.ivarraw[row - 1, :] + sci2d.ivarraw[row + 1, :]
            ) / 2

        # Spectral resolution: the FWHM of the arc lines
        try:
            spec_resln = sci2d["wavesol"]["measured_fwhm"].value[0]
        except KeyError:
            # A typo in old PypeIt versions
            spec_resln = sci2d["wavesol"]["mesured_fwhm"].value[0]

        # Transposed images: axis 0 is spatial, axis 1 is spectral
        flux = np.array(sci2d.sciimg.T)
        ivar = np.array(sci2d.ivarraw.T)
        waveimg = np.array(sci2d.waveimg.T)
        bpmmask = np.array(sci2d.bpmmask.mask.T)
        tilts = np.array(sci2d.tilts.T)

        flux[bpmmask != 0] = np.nan
        ivar[bpmmask != 0] = 0

        # Estimate the distance from the trace.
        # For each pixel in the 2D spectrum with a certain wavelength,
        # find the corresponding spectral pixel within the trace at the same wavelength
        trace_spec = Interp2D_Grid(
            points=(jnp.arange(waveimg.shape[0]), jnp.arange(waveimg.shape[1])),
            values=jnp.asarray(waveimg),
        )(np.stack([trace_spat_pix, np.arange(waveimg.shape[1])], axis=-1))
        assert trace_spec is not None, "Trace spectral pixels are not available."
        trace_spec_pix = jnp.where(
            tilts != 0,
            Interp1D_Grid(
                points=jnp.asarray(trace_spec), values=jnp.arange(len(trace_spec))
            )(waveimg),
            np.nan,
        )

        # Indices of the spatial and spectral pixels
        spat_pix = jnp.tile(jnp.arange(waveimg.shape[0]), (waveimg.shape[1], 1)).T
        spec_pix = jnp.tile(jnp.arange(waveimg.shape[1]), (waveimg.shape[0], 1))
        dist_spat_pix = spat_pix - trace_spat_pix
        dist_spec_pix = spec_pix - trace_spec_pix
        # Signed distance (arcsec): positive on the side of increasing spatial pixel
        dist = (
            np.sqrt(dist_spat_pix**2 + dist_spec_pix**2)
            * np.where(dist_spat_pix > 0, 1, -1)
            * pixel_scale
        )

        # Preliminary sky line removal - reduce the noise introduced in the rectification
        wave_sky = np.nanmedian(np.where(waveimg > 0, waveimg, np.nan), axis=0)
        flux_sky = np.nanmedian(np.where(waveimg > 0, flux, np.nan), axis=0)
        global_sky = Interp1D_Grid(points=wave_sky, values=flux_sky, method="cubic")(
            waveimg
        )

        # Save the 2D spectrum wavelength solution & distance from the trace to a fits file
        primary_hdu = fits.PrimaryHDU()
        hdr = primary_hdu.header
        hdr["RAWFILE"] = spec2d_file
        hdr["DET"] = det
        # Save the distance array
        hdu_dist = fits.ImageHDU(dist, name="DIST")
        hdu_dist.header["UNIT"] = "arcsec"
        hdu_dist.header["COMMENT"] = "Distance from the trace"
        # Save the wavelength solution
        hdu_waveimg = fits.ImageHDU(waveimg, name="WAVEIMG")
        hdu_waveimg.header["UNIT"] = "Angstrom"
        hdu_waveimg.header["COMMENT"] = "Wavelength solution"
        # Save the global sky background
        hdu_global_sky = fits.ImageHDU(global_sky, name="GLOBALSKY")
        hdu_global_sky.header["COMMENT"] = (
            "Global sky background (average across the slit)"
        )
        hdul = fits.HDUList([primary_hdu, hdu_dist, hdu_waveimg, hdu_global_sky])
        hdul.writeto(spec2d_file.replace(".fits", "_preproc.fits"), overwrite=True)

        # Copy the spec1d fits file. Best effort (like the former shell `cp`):
        # the spec1d file is only required above when obj_id is given.
        spec1d_copy = spec1d_file.replace(".fits", "_hostsub.fits")
        if os.path.exists(spec1d_file):
            shutil.copy(spec1d_file, spec1d_copy)
        else:
            msgs.info(f"{spec1d_file} not found; skipping the copy to {spec1d_copy}.")

        # Remove spatial rows outside the slit (all values are NaN)
        valid_spat = jnp.any(np.isfinite(dist), axis=1)
        # Also remove the remaining spatial rows that contain any NaN along the
        # spectral axis. Spectral columns are kept on purpose, so that column
        # indices stay aligned with `trace_spec` / `spec_rect` in `_rectify`.
        complete_spat = jnp.all(np.isfinite(dist[valid_spat]), axis=1)
        dist = dist[valid_spat][complete_spat]
        waveimg = waveimg[valid_spat][complete_spat]
        flux = flux[valid_spat][complete_spat]
        ivar = ivar[valid_spat][complete_spat]
        global_sky = global_sky[valid_spat][complete_spat]

        if spat_rect is None:
            # Spatial coordinates: within the range of the slit
            # Remove the first and last spatial pixels to avoid the edge effects
            spat_rect = (
                jnp.arange(
                    dist[0].max() // pixel_scale, dist[-1].min() // pixel_scale + 1
                )[1:-1]
                * pixel_scale
            )
        msgs.info(
            f"Distance from the trace: {spat_rect[0]:.2f} - {spat_rect[-1]:.2f} arcsec"
        )

        if spec_rect is None:
            # Spectral coordinates: at the location of the trace
            # Remove the first and last spectral pixels to avoid the edge effects
            assert trace_spec is not None, "Trace spectral pixels are not available."
            spec_rect = trace_spec[1:-1]
        msgs.info(
            f"Wavelength range: {spec_rect[0]:.2f} - {spec_rect[-1]:.2f} Angstrom"
        )

        assert ra is not None, "RA is not available."
        assert dec is not None, "DEC is not available."
        assert position_angle is not None, "Position angle is not available."

        return cls(
            pixel_scale=pixel_scale,
            center_ra=ra,
            center_dec=dec,
            slit_wid=slit_wid,
            position_angle=position_angle,
            spat_resln=spat_resln,
            spec_resln=spec_resln,
            flux=flux,
            flux_global_sky=global_sky,
            flux_ivar=ivar,
            dist=dist,
            waveimg=waveimg,
            spat_rect=spat_rect,
            spec_rect=spec_rect,
            cache_path=spec2d_file.replace(".fits", "_rect.fits"),
            to_caches=True,
            **kwargs,
        )

    @classmethod
    def from_fits(cls, fits_path: Optional[str] = None, **kwargs):
        """
        Load a rectified 2D spectrum from a cache file written by `to_fits`.

        Parameters
        ----------
        fits_path : str
            The path to the cache FITS file.
        **kwargs
            Extra keyword arguments passed on to the constructor.

        Raises
        ------
        ValueError
            If ``fits_path`` is None.
        FileNotFoundError
            If the file does not exist.
        """
        if fits_path is None:
            raise ValueError("No fits file provided.")

        msgs.info(f"Loading 2D spectrum from {fits_path}...")
        try:
            hdul = fits.open(fits_path)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Fits file {fits_path} not found.") from err

        # The context manager closes the file even if a keyword/extension is missing;
        # np.array(...) copies the data so it stays valid after closing.
        with hdul:
            header = hdul[0].header
            data = dict(
                pixel_scale=header["PIXSCALE"],
                center_ra=header["CENRA"],
                center_dec=header["CENDEC"],
                slit_wid=header["SLITWID"],
                position_angle=header["POSANG"],
                spat_resln=header["SPATRESLN"],
                spec_resln=header["SPECRESLN"],
                spat_rect=np.array(hdul["SPAT"].data, dtype=np.float64),
                spec_rect=np.array(hdul["SPEC"].data, dtype=np.float64),
                flux_rect=np.array(hdul["FLUX"].data, dtype=np.float64),
                flux_ivar_rect=np.array(hdul["IVAR"].data, dtype=np.float64),
                sky_offset=header["SKYOFFSET"],
                cache_path=header["SPECFILE"],
                to_caches=False,
            )
        return cls(**data, **kwargs)

    @classmethod
    @show_and_save
    def coadd2d(
        cls, spec_data_list: list["SpecData"], **kwargs
    ) -> tuple["SpecData", np.ndarray | None]:
        """
        Coadd multiple SpecData objects with inverse-variance weighting.

        Cosmic rays are flagged per frame by comparing it with the median of
        all frames (in units of the frame's error); a diagnostic figure of the
        normalized differences and the coadd is drawn.

        Parameters
        ----------
        spec_data_list : list[SpecData]
            A list of SpecData objects to be coadded. They must share the
            pixel scale, the rectified grid and the array shapes.
        **kwargs
            Extra keyword arguments passed on to the constructor.

        Returns
        -------
        SpecData
            The coadded SpecData object (metadata taken from the first one).
        cr_mask : np.ndarray or None
            Boolean mask per input frame, True where the pixel was kept
            (not flagged as a cosmic ray). None when only one object is given.

        Raises
        ------
        ValueError
            If the list is empty or the objects are incompatible.
        """
        from astropy.stats import mad_std

        n_obj = len(spec_data_list)
        if n_obj == 0:
            raise ValueError("No SpecData object provided.")
        if n_obj == 1:
            return spec_data_list[0], None

        first = spec_data_list[0]

        # Check if the pixel scales are the same
        if not all(sd.pixel_scale == first.pixel_scale for sd in spec_data_list):
            raise ValueError("All SpecData objects must have the same pixel scale.")

        # Check if the spatial and spectral coordinates are the same
        if not all(
            (sd.spat_rect == first.spat_rect).all()
            and (sd.spec_rect == first.spec_rect).all()
            for sd in spec_data_list
        ):
            raise ValueError(
                "All SpecData objects must have the same spatial and spectral coordinates."
            )

        # Check if the flux and ivar arrays have the same shape
        if not all(
            (sd.flux_rect.shape == first.flux_rect.shape)
            and (sd.flux_ivar_rect.shape == first.flux_ivar_rect.shape)
            for sd in spec_data_list
        ):
            raise ValueError(
                "All SpecData objects must have the same flux and ivar arrays."
            )

        msgs.info(f"Coadding 2D spectra from {n_obj} objects...")

        # Stack along a new leading "frame" axis
        flux_rect_stack = jnp.stack([sd.flux_rect for sd in spec_data_list], axis=0)
        flux_ivar_rect_stack = jnp.stack(
            [sd.flux_ivar_rect for sd in spec_data_list], axis=0
        )
        # ivar == 0 gives an infinite error, which `valid_mask` excludes
        flux_err_rect_stack = flux_ivar_rect_stack**-0.5

        valid_mask = np.isfinite(flux_rect_stack) & np.isfinite(flux_err_rect_stack)

        # Identify cosmic rays by subtracting the median frame (computed once)
        median_frame = jnp.nanmedian(flux_rect_stack, axis=0)
        diff = np.zeros_like(flux_rect_stack)
        pos_outlier = np.ones_like(flux_rect_stack, dtype=bool)
        neg_outlier = np.ones_like(flux_rect_stack, dtype=bool)
        for i in range(n_obj):
            diff[i] = (flux_rect_stack[i] - median_frame) / flux_err_rect_stack[i]
            threshold = 5 * mad_std(diff[i][np.isfinite(diff[i])])
            pos_outlier[i] = diff[i] > threshold
            neg_outlier[i] = diff[i] < -threshold

        cr_mask = np.ones_like(flux_rect_stack, dtype=bool)
        for i in range(n_obj):
            # Mask the pixels with large positive deviations (cosmic rays)
            cr_mask[i] = ~pos_outlier[i]
            for j in range(n_obj):
                if i != j:
                    # Mask the pixels with large negative deviations in other
                    # frames, unless this frame is also negative there
                    cr_mask[i] &= ~neg_outlier[j] | neg_outlier[i]

        # Inverse-variance weighted mean over the kept pixels
        w = flux_ivar_rect_stack
        keep = valid_mask & cr_mask
        weights = np.where(keep, w, 0)
        weighted_values = np.where(keep, flux_rect_stack * w, 0)
        flux_rect = np.sum(weighted_values, axis=0) / np.sum(weights, axis=0)

        cmap = plt.get_cmap("gray")
        cmap.set_bad("red", 1.0)
        _, ax = plt.subplots(
            n_obj + 1,
            1,
            figsize=(12, 4 * (n_obj + 1)),
            constrained_layout=True,
            sharex=True,
            sharey=True,
        )
        for k in range(n_obj):
            ax[k].imshow(
                diff[k], cmap=cmap, origin="lower", aspect="auto", vmin=-5, vmax=5
            )
            ax[k].set_title(f"Object {k + 1} - Mean")
            ax[k].set_ylabel(r"$\mathrm{Spat\ [arcsec]}$")
        ax[-1].imshow(
            flux_rect,
            cmap=cmap,
            origin="lower",
            aspect="auto",
            vmin=np.nanpercentile(flux_rect, 5),
            vmax=np.nanpercentile(flux_rect, 95),
        )
        ax[-1].set_xlabel(r"$\mathrm{Spec\ [pixel]}$")
        ax[-1].set_ylabel(r"$\mathrm{Spat\ [arcsec]}$")
        ax[-1].set_title("Coadded")

        # Propagate errors of the weighted mean
        weighted_errors = np.where(keep, (flux_err_rect_stack * w) ** 2, 0)
        flux_err_rect = np.sqrt(
            np.sum(weighted_errors, axis=0) / np.sum(weights, axis=0) ** 2
        )
        flux_ivar_rect = np.where(np.isfinite(flux_err_rect), flux_err_rect**-2, 0)

        return (
            cls(
                pixel_scale=first.pixel_scale,
                center_ra=first.center_ra,
                center_dec=first.center_dec,
                slit_wid=first.slit_wid,
                position_angle=first.position_angle,
                spat_resln=first.spat_resln,
                spec_resln=first.spec_resln,
                spat_rect=first.spat_rect,
                spec_rect=first.spec_rect,
                flux_rect=flux_rect,
                flux_ivar_rect=jnp.asarray(flux_ivar_rect),
                **kwargs,
            ),
            cr_mask,
        )

    def to_fits(self):
        """
        Save the rectified 2D spectrum and its metadata to ``self.cache_path``.

        The primary header holds the scalar metadata; the SPAT, SPEC, FLUX and
        IVAR extensions hold the arrays (as float64). An existing file is
        overwritten.

        Raises
        ------
        ValueError
            If ``self.cache_path`` is not set.
        """
        if self.cache_path is None:
            raise ValueError("No cache path is set for saving the 2D spectra.")

        # Convert JAX arrays directly to float64 ndarrays (avoids building
        # large intermediate Python lists).
        spat_rect = np.asarray(self.spat_rect, dtype=np.float64)
        spec_rect = np.asarray(self.spec_rect, dtype=np.float64)
        flux_rect = np.asarray(self.flux_rect, dtype=np.float64)
        flux_ivar_rect = np.asarray(self.flux_ivar_rect, dtype=np.float64)

        # Primary HDU with the metadata
        primary_hdu = fits.PrimaryHDU()
        hdr = primary_hdu.header
        hdr["PIXSCALE"] = self.pixel_scale
        hdr["CENRA"] = self.center_ra
        hdr["CENDEC"] = self.center_dec
        hdr["SLITWID"] = self.slit_wid
        hdr["POSANG"] = self.position_angle
        hdr["SPATRESLN"] = self.spat_resln
        hdr["SPECRESLN"] = self.spec_resln
        hdr["SPECFILE"] = self.cache_path
        hdr["SKYOFFSET"] = self.sky_offset

        spat_hdu = fits.ImageHDU(spat_rect, name="SPAT")
        spat_hdu.header["UNIT"] = "arcsec"
        spat_hdu.header["COMMENT"] = "Spatial coordinates"

        spec_hdu = fits.ImageHDU(spec_rect, name="SPEC")
        spec_hdu.header["UNIT"] = "Angstrom"
        spec_hdu.header["COMMENT"] = "Spectral coordinates"

        flux_hdu = fits.ImageHDU(flux_rect, name="FLUX")
        flux_hdu.header["COMMENT"] = "Flux values"

        ivar_hdu = fits.ImageHDU(flux_ivar_rect, name="IVAR")
        ivar_hdu.header["COMMENT"] = "Inverse variance values"

        hdul = fits.HDUList([primary_hdu, spat_hdu, spec_hdu, flux_hdu, ivar_hdu])
        hdul.writeto(self.cache_path, overwrite=True)

    def to_SpecModel(
        self,
        **kwargs,
    ) -> SpecModel:
        """
        Convert the 2D spectra to a SpecModel object.

        Parameters
        ----------
        **kwargs
            Passed on to `SpecModel`. The optional ``spec_resln`` and
            ``spat_resln`` entries are removed from kwargs and, if not None,
            overwrite the corresponding attributes of this object first.

        Returns
        -------
        SpecModel
            Model built from the rectified flux, its error (``ivar**-0.5``)
            and this object's metadata.
        """
        # Update the spectral and spatial resolutions if specified in the config file
        spec_resln_cfg = kwargs.pop("spec_resln", None)
        spat_resln_cfg = kwargs.pop("spat_resln", None)
        if spec_resln_cfg is not None:
            self.spec_resln = spec_resln_cfg
            msgs.info(
                f"Spectral resolution specified in the config file: setting it to {self.spec_resln:.2f} Ang."
            )
        if spat_resln_cfg is not None:
            self.spat_resln = spat_resln_cfg
            msgs.info(
                f"Spatial resolution specified in the config file: setting it to {self.spat_resln:.2f} arcsec."
            )

        return SpecModel(
            dat=self.flux_rect,
            dat_err=self.flux_ivar_rect**-0.5,
            spat=self.spat_rect,
            spec=self.spec_rect,
            pixel_scale=self.pixel_scale,
            center_ra=self.center_ra,
            center_dec=self.center_dec,
            slit_wid=self.slit_wid,
            position_angle=self.position_angle,
            spat_resln=self.spat_resln,
            spec_resln=self.spec_resln,
            **kwargs,
        )

    def update_pypeit_skymodel(self, spec_model: SpecModel, spec2d_file: str):
        """
        Update the sky model and the mask in the PypeIt spec2d file.

        Reads ``<spec2d>_preproc.fits`` and ``<spec2d>_rect.fits`` and writes
        ``<spec2d>_hostsub.fits`` with the new sky model and mask.

        Parameters
        ----------
        spec_model : SpecModel
            The fitted model providing the sky and host predictions.
        spec2d_file : str
            The original PypeIt spec2d file.

        Raises
        ------
        FileNotFoundError
            If the spec2d, preproc or rectified file is missing.
        ValueError
            If ``self.sky_offset`` is not available.
        """
        from pypeit import spec2dobj

        # https://pypeit.readthedocs.io/en/1.17.0/out_masks.html
        BIT_CR = 2**1  # Cosmic rays
        BIT_OFFSLIT = 2**4  # Off-slit pixels

        preproc_file = spec2d_file.replace(".fits", "_preproc.fits")
        rect_file = spec2d_file.replace(".fits", "_rect.fits")
        output_file = spec2d_file.replace(".fits", "_hostsub.fits")

        # All three inputs are read below, so each one must exist
        missing = [
            path
            for path in (spec2d_file, preproc_file, rect_file)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError(
                f"Spec2D file or rectified file not found: {', '.join(missing)}"
            )
        if self.sky_offset is None:
            raise ValueError("The sky offset is not available.")

        # The new sky model will be composed of
        # (i) the preliminary global sky (average across the slit) removed before the rectification
        # (ii) the global sky (average of the sky region) subtracted after the rectification
        # (iii) the local sky modeled with HostSub

        # Preliminary sky line removal - reduce the noise introduced in the rectification
        with fits.open(preproc_file) as hdul_preproc:
            waveimg = jnp.array(hdul_preproc["WAVEIMG"].data, dtype=jnp.float32)
            dist = jnp.array(
                hdul_preproc["DIST"].data - self.sky_offset, dtype=jnp.float32
            )
            global_sky_pre = jnp.array(
                hdul_preproc["GLOBALSKY"].data, dtype=jnp.float32
            )
            det = hdul_preproc[0].header["DET"]

        sci2d = spec2dobj.Spec2DObj.from_file(spec2d_file, detname=det)

        # Mask the regions outside the wavelength/dist range
        offslit = np.asarray(
            np.where(
                (waveimg < spec_model.spec.item(0))
                | (waveimg > spec_model.spec.item(-1))
                | (dist < spec_model.spat_edges["host"][0])
                | (dist > spec_model.spat_edges["host"][1]),
                BIT_OFFSLIT,
                0,
            ),
            dtype=np.int16,
        )
        sci2d.bpmmask.mask = sci2d.bpmmask.mask | offslit.T

        # Update the bpmmask to include cosmic rays
        with fits.open(rect_file) as hdul_rect:
            if "CR_MASK" in hdul_rect:
                # CR_MASK is True where the pixel was kept, so negate it
                mask_cr = jnp.argwhere(
                    ~jnp.array(hdul_rect["CR_MASK"].data, dtype=jnp.bool)
                )
                wave_rect = jnp.array(hdul_rect["SPEC"].data, dtype=jnp.float32)
                dist_rect = jnp.array(hdul_rect["SPAT"].data, dtype=jnp.float32)

                # Rectified grid: coords_rect is [dist grid, wave grid]
                coords_rect = jnp.meshgrid(dist_rect, wave_rect, indexing="ij")
                # Nearest-neighbour interpolator for the cosmic-ray mask
                interpolator = Interp2D_Scipy(method="nearest")
                points = jnp.stack(coords_rect, axis=-1)
                # Mask values: BIT_CR for cosmic rays, 0 otherwise
                mask_values = jnp.zeros(points.shape[:-1], dtype=jnp.int32)
                mask_values = mask_values.at[mask_cr[:, 0], mask_cr[:, 1]].set(BIT_CR)
                interpolator.fit(points, mask_values)

                # Query points on the original (non-uniform) grid; non-finite
                # coordinates are replaced by 0 before the lookup
                query_points = jnp.stack([dist, waveimg], axis=-1)
                query_points = jnp.where(np.isfinite(query_points), query_points, 0)
                mapped_cr_mask = np.asarray(
                    interpolator.predict(query_points.reshape(-1, 2)).reshape(
                        waveimg.shape
                    ),
                    dtype=np.int16,
                )
                sci2d.bpmmask.mask = sci2d.bpmmask.mask | mapped_cr_mask.T

        # The global sky subtracted after the rectification
        assert spec_model.f_sky_1d.y is not None, (
            "The global sky model is not available."
        )
        global_sky_post = Interp1D_Grid(
            points=spec_model.f_sky_1d.X.ravel(),
            values=spec_model.f_sky_1d.y,
            method="cubic",
        )(waveimg)

        # The local sky, evaluated only inside the modeled spatial/spectral range
        x = np.stack([dist.ravel(), waveimg.ravel()], axis=-1).reshape(-1, 2)
        spat_mask = (dist.ravel() >= spec_model.spat_edges["host"][0]) & (
            dist.ravel() <= spec_model.spat_edges["host"][-1]
        )
        spec_mask = (waveimg.ravel() >= spec_model.spec[0]) & (
            waveimg.ravel() <= spec_model.spec[-1]
        )
        local_mask = spat_mask & spec_mask
        x_mask = jnp.asarray(x[local_mask], dtype=jnp.float32)
        sky_host_prior, _ = spec_model.host_prior(x_mask)
        _, _, (sky_pred, _) = spec_model._get_pred(
            spec_model._gp_1d, spec_model._gp_2d, x_mask, return_var=True
        )
        sky_local = np.zeros_like(dist.ravel())
        sky_local[local_mask] = sky_pred + sky_host_prior
        sky_local = sky_local.reshape(dist.shape)

        sci2d.skymodel = np.array(global_sky_pre + global_sky_post + sky_local).T
        all_spec2d = spec2dobj.AllSpec2DObj()
        all_spec2d[det] = sci2d

        with fits.open(spec2d_file) as hdul:
            pri_hdr = hdul[0].header
        all_spec2d.write_to_fits(output_file, pri_hdr=pri_hdr)

    def _rectify(
        self,
        points: ArrayLike,
        f_values: tuple[Array, Array],
        batch_size: int = 1024,
        interp_method: str = "rbf",
    ) -> tuple[Array, Array]:
        """
        Rectify the 2D spectrum onto the ``spat_rect`` x ``spec_rect`` grid.

        Parameters
        ----------
        points : ArrayLike
            The (spatial, spectral) coordinates of each pixel, stacked on the
            last axis; columns are indexed like ``spec_rect``.
        f_values : tuple[Array, Array]
            The flux and ivar values at ``points``.
        batch_size : int, optional (default: 1024, in pixels)
            The number of spectral pixels interpolated per batch.
        interp_method : {"rbf", "linear", "scipy"}, optional (default: "rbf")
            The interpolation method.

        Returns
        -------
        flux_rect, flux_ivar_rect : Array
            Arrays of shape ``(len(spat_rect), len(spec_rect))``.

        Raises
        ------
        ValueError
            If ``points`` contains non-finite values or the method is invalid.
        """
        if not jnp.all(jnp.isfinite(points)):
            raise ValueError("Some points are NaN.")
        points = jnp.asarray(points, dtype=jnp.float32)

        if interp_method not in ["rbf", "linear", "scipy"]:
            raise ValueError("Invalid interpolation method.")
        elif interp_method == "rbf":
            msgs.info("Interpolating the flux with RBF...")
        elif interp_method == "linear":
            msgs.info("Interpolating the flux with linear method...")
        else:
            msgs.info("Interpolating the flux with scipy.interpolate.griddata...")

        # Initialize the flux and ivar arrays on a semi-rectified grid
        # The spatial/spectral coordinate monotonically increases in each row/column
        flux, ivar = f_values
        spec_pix_rect, spat_pix_rect = jnp.meshgrid(self.spec_rect, self.spat_rect)
        spec_batch_idx = jnp.array_split(
            jnp.arange(len(self.spec_rect)), len(self.spec_rect) // batch_size + 1
        )
        flux_rect = np.zeros((len(self.spat_rect), len(self.spec_rect)))
        flux_ivar_rect = np.zeros((len(self.spat_rect), len(self.spec_rect)))

        # FWHM -> Gaussian sigma
        scales = (self.spat_resln / 2.355, self.spec_resln / 2.355)
        if interp_method == "rbf":
            interp2d = Interp2D_RBF(kernel="gaussian", scales=scales)
            interp2d_ivar = Interp2D_RBF(kernel="gaussian", scales=scales)
        elif interp_method == "linear":
            interp2d = Interp2D_Linear(scales=scales)
            interp2d_ivar = Interp2D_Linear(scales=scales)
        else:
            interp2d = Interp2D_Scipy(method="linear", scales=scales)
            interp2d_ivar = Interp2D_Scipy(method="linear", scales=scales)

        # Interpolate the flux in batches
        for idx_list in spec_batch_idx:
            msgs.info(
                f"Interpolating the flux in the spectral range {self.spec_rect[idx_list[0]]:.2f} - {self.spec_rect[idx_list[-1]]:.2f} Ang..."
            )
            # The range of the spectrum to interpolate
            spec_min = idx_list.item(0)
            spec_max = idx_list.item(-1) + 1

            # Fit the interpolators on the batch padded by one column per side
            fit_min = max(spec_min - 1, 0)
            fit_max = min(spec_max + 1, points.shape[1])
            interp2d.fit(
                points=points[:, fit_min:fit_max], values=flux[:, fit_min:fit_max]
            )
            interp2d_ivar.fit(
                points=points[:, fit_min:fit_max], values=ivar[:, fit_min:fit_max]
            )

            # Query points
            query_points = jnp.stack(
                [
                    spat_pix_rect[:, spec_min:spec_max].ravel(),
                    spec_pix_rect[:, spec_min:spec_max].ravel(),
                ],
                axis=-1,
            )
            batch_shape = flux_rect[:, spec_min:spec_max].shape
            flux_rect[:, spec_min:spec_max] = interp2d.predict(
                query_points=query_points
            ).reshape(batch_shape)
            flux_ivar_rect[:, spec_min:spec_max] = interp2d_ivar.predict(
                query_points=query_points
            ).reshape(batch_shape)

        return jnp.asarray(flux_rect), jnp.asarray(flux_ivar_rect)

    @show_and_save
    def _get_offset(
        self, points: Array, flux: Array, host_prior: Callable, mask_wid: float = 2.0
    ) -> float:
        """
        Estimate the spatial offset (arcsec) that centers the trace of the science object.

        The spatial flux profile (sigma-clipped mean per spatial bin) is
        cross-correlated with the profile predicted by ``host_prior`` at the
        flux-weighted mean wavelength, for offsets from -2 to 2 arcsec in steps
        of ``pixel_scale / 10``. Positions with ``|spat| < mask_wid`` are
        excluded from the correlation. A diagnostic figure is drawn.

        Parameters
        ----------
        points : Array
            The (distance, wavelength) of each pixel, stacked on the last axis.
        flux : Array
            The flux at each pixel.
        host_prior : Callable
            Called with an ``(N, 2)`` array of (spat, wavelength); its first
            output is used as the profile.
        mask_wid : float, optional (default: 2.0)
            Half-width (arcsec) of the central region excluded from the fit.

        Returns
        -------
        float
            The offset maximizing the correlation coefficient.
        """

        def binned_mean_with_clipping(
            points: Array,
            values: Array,
            bins: int,
            bin_range: tuple[float, float],
            **kwargs,
        ):
            """Sigma-clipped mean of ``values`` in equal-width bins of ``points``.

            Bins with too few points or <= 80% surviving values, and
            non-positive means, are returned as NaN.
            """
            from astropy.stats import sigma_clip, mad_std

            points = jnp.asarray(points, dtype=jnp.float32)
            values = jnp.asarray(values, dtype=jnp.float32)
            bin_edges = np.linspace(bin_range[0], bin_range[1], bins + 1)
            bin_indices = np.digitize(points, bin_edges) - 1
            # Loop-invariant: finite values are the same for every bin
            valid_mask = np.isfinite(values)

            obs = np.full(bins, np.nan)
            n_bin = np.array([np.sum(bin_indices == i) for i in range(bins)])
            med_n_bin = np.median(n_bin)
            std_n_bin = mad_std(n_bin)
            for i in range(bins):
                # Skip the bin if there are too few points.
                # std_n_bin can be zero; the +1 then keeps essentially all bins.
                if n_bin[i] < med_n_bin - 3 * (std_n_bin + 1):
                    continue
                y_in_bin = values[(bin_indices == i) & valid_mask]
                y_clipped = sigma_clip(y_in_bin, stdfunc="mad_std", **kwargs)
                assert isinstance(y_clipped, np.ma.MaskedArray), (
                    "Clipped values are not a masked array."
                )
                # getmaskarray also handles a scalar `nomask`
                kept = ~np.ma.getmaskarray(y_clipped)
                # Compute the mean of the clipped values (requires more than 80% valid data)
                if kept.sum() > 0.80 * len(y_clipped):
                    obs[i] = np.nanmean(y_clipped.compressed())
            obs[obs <= 0] = np.nan
            return obs

        slit_len_max = min(
            min(-self.spat_rect.item(0), self.spat_rect.item(-1)) * 2, 30
        )
        spat = self.spat_rect[np.abs(self.spat_rect) <= slit_len_max / 2]
        obs = binned_mean_with_clipping(
            points[..., 0],
            flux,
            bins=len(spat),
            bin_range=(
                spat.item(0) - self.pixel_scale / 2,
                spat.item(-1) + self.pixel_scale / 2,
            ),
            sigma=5,
        )
        sci_obj_mask = (jnp.abs(spat) >= mask_wid) & jnp.isfinite(obs)

        # Profile from the prior - flux evaluated at the weighted-mean wavelength
        wv_mean = (
            jnp.nanmean(points[..., 1] * flux)
            / jnp.nanmean(flux)
            * jnp.ones_like(spat)
        )

        def _min_max_scale(arr):
            """Rescale to [0, 1] ignoring NaNs."""
            return (arr - jnp.nanmin(arr)) / (jnp.nanmax(arr) - jnp.nanmin(arr))

        def corr_coef(offset):
            # Evaluate the profile from the prior at the offset position
            prior = host_prior(jnp.stack([spat - offset, wv_mean], axis=-1))[0]
            return jnp.corrcoef(
                _min_max_scale(prior[sci_obj_mask]),
                _min_max_scale(obs[sci_obj_mask]),
            )[0, 1]

        offset_list = np.arange(-2, 2 + self.pixel_scale / 10, self.pixel_scale / 10)
        ccf = jax.vmap(corr_coef)(offset_list)
        # F_obs(x_spat) = F_prior(x_spat - offset_opt)
        # => F_obs(offset_opt) = F_prior(0) = Location of the SN
        # => Subtract offset_opt from the spatial coordinates of the 2D spectrum
        offset_opt = offset_list[np.argmax(ccf)]

        _, ax = plt.subplots(2, 1, figsize=(12, 8), constrained_layout=True)
        ax[0].plot(offset_list, ccf)
        ax[0].set_xlabel(r"$\mathrm{SCI - STD\ [arcsec]}$")
        ax[0].set_ylabel(r"$\mathrm{Correlation\ Coefficient}$")
        ax[0].xaxis.set_major_locator(plt.MultipleLocator(0.2))
        ax[0].xaxis.set_minor_locator(plt.MultipleLocator(0.02))
        ax[0].axvline(offset_opt, color="k", linestyle="--")
        ax[1].scatter(
            spat,
            (obs - jnp.nanmin(obs)) / (jnp.nanmax(obs) - jnp.nanmin(obs)),
            label="obs",
        )
        profile_prior = host_prior(
            jnp.stack([spat - offset_opt, wv_mean], axis=-1)
        )[0]
        ax[1].scatter(
            spat,
            (profile_prior - profile_prior.min())
            / (profile_prior.max() - profile_prior.min()),
            label="prior",
        )
        ax[1].set_xlabel(r"$\mathrm{Spat\ [arcsec]}$")
        ax[1].set_ylabel(r"$\mathrm{Normalized\ Counts}$")
        ax[1].xaxis.set_major_locator(plt.MultipleLocator(5))
        ax[1].xaxis.set_minor_locator(plt.MultipleLocator(1))
        ax[1].legend()

        return offset_opt