# Source code for hostsub_gp.host_image

# hostsub_gp/host_image.py

__all__ = ["PS1Image", "SDSSImage"]

import numpy as np

from astropy.io import fits
from astropy import units as u
from astropy.table import Table
from astropy.coordinates import SkyCoord
from astropy.stats import mad_std

import subprocess
import os

from ._utils import msgs
from ._utils._astronometry import query_astrometry_net_wcs

from numpy.typing import NDArray


class ImageProduct:
    """
    Class to save main data products after resampling an archival image onto the slit

    The input image is reprojected onto a TAN-projected WCS that is centred on
    ``(center_ra, center_dec)``, rotated according to ``position_angle`` and
    sampled at the pixel scale of the input image.

    Parameters
    ----------
    center_ra, center_dec : float
        Sky position written to ``CRVAL`` of the target WCS (RA/Dec; degrees
        by FITS convention - TODO confirm against callers).
    slit_len, slit_wid : float
        Slit length and width in arcsec (they are divided by the pixel scale
        in arcsec/pixel to obtain sizes in pixels).
    position_angle : float
        Slit position angle in degrees.
    img : NDArray
        Archival image to resample.
    header : fits.Header
        FITS header providing the WCS of ``img``.
    flt : str
        Filter name, stored as ``self.flt``.
    wv_eff : float
        Effective wavelength of the filter, stored as ``self.wv_eff``.

    Attributes
    ----------
    img : NDArray
        Reprojected image with shape ``self.shape``.
    err : float
        ``mad_std`` of the finite pixels fainter than the histogram mode.
    pixel_scale : float
        Pixel scale in arcsec/pixel.
    shape : tuple[int, int]
        Output shape: (pixels along the slit, 10 x pixels across the slit).
    slit_len_pix, slit_wid_pix : float
        Slit length and width in pixels.
    wcs : astropy.wcs.WCS
        The slit-aligned target WCS.
    spat_slit, spat_slit_wid : NDArray
        Offsets from the reference pixel along / perpendicular to the slit,
        in arcsec.
    """

    def __init__(
        self,
        center_ra: float,
        center_dec: float,
        slit_len: float,
        slit_wid: float,
        position_angle: float,
        img: NDArray,
        header: fits.Header,
        flt: str,
        wv_eff: float,
    ):
        from astropy.wcs import WCS
        from astropy.wcs.utils import proj_plane_pixel_scales
        from reproject import reproject_adaptive

        self.flt = flt
        self.wv_eff = wv_eff

        # Load FITS image and WCS info
        wcs = WCS(header)

        # Step 1. Get the pixel scale of the image cutout
        # proj_plane_pixel_scales works for both CD and PC+CDELT headers and
        # is always positive. Reading CDELT1 directly would give a negative
        # scale for the usual RA axis and hence a negative output shape.
        pixel_scale = proj_plane_pixel_scales(wcs)[0] * 3600  # arcsec/pixel

        # Step 2. Resample the image to the slit with a given position angle

        slit_len_pix = slit_len / pixel_scale  # Slit length in pixels
        slit_wid_pix = slit_wid / pixel_scale  # Slit width in pixels
        # Both dimensions are rounded up to an even number of pixels
        shape = (
            int(np.ceil(slit_len_pix / 2)) * 2,
            int(np.ceil(slit_wid_pix / 2))
            * 2
            * 10,  # 10 times the width - room for blurring
        )

        # Define the target wcs
        wcs_target = WCS(naxis=2)
        wcs_target.wcs.ctype = ["RA---TAN", "DEC--TAN"]
        wcs_target.wcs.crval = [center_ra, center_dec]
        # Reference pixel at the centre of the output array (FITS 1-based)
        wcs_target.wcs.crpix = [(shape[1] + 1) / 2, (shape[0] + 1) / 2]

        # Include rotation in CD matrix
        # NOTE(review): the rotation applied is position_angle + 180 deg
        # (np.pi); an earlier comment here said "add 90 degrees to align the
        # slit with the x-axis". With shape = (length, 10 * width) the slit
        # length runs along array axis 0 - confirm the intended convention.
        theta = np.deg2rad(position_angle) + np.pi
        wcs_target.wcs.cd = (
            np.array(
                [
                    [
                        -pixel_scale * np.cos(theta),
                        -pixel_scale * np.sin(theta),
                    ],
                    [pixel_scale * np.sin(theta), -pixel_scale * np.cos(theta)],
                ]
            )
            / 3600
        )

        # Set additional required WCS parameters
        wcs_target.wcs.radesys = "ICRS"
        wcs_target.wcs.equinox = 2000.0

        # Reproject the image to the slit-aligned WCS (footprint is discarded)
        data_reproj, _ = reproject_adaptive(
            (img, wcs),
            wcs_target,
            shape_out=shape,
        )

        self.img = data_reproj

        # Noise estimate: locate the mode of the pixel histogram and take the
        # MAD-based standard deviation of the pixels fainter than the mode.
        # NOTE: no correction for using only one side of the distribution is
        # applied (a 1 / sqrt(1 - 2 / pi) factor was tried and disabled).
        finite = np.isfinite(self.img)
        bins, bin_edges = np.histogram(
            self.img[finite],
            bins=200,
        )
        bin_center = 0.5 * (bin_edges[1:] + bin_edges[:-1])
        mode = bin_center[np.argmax(bins)]
        self.err = mad_std(self.img[finite & (self.img < mode)])

        self.pixel_scale = pixel_scale
        self.shape = shape
        self.slit_len_pix = slit_len_pix
        self.slit_wid_pix = slit_wid_pix
        self.wcs = wcs_target

        # Obtain the pixel coordinates along the slit
        self.spat_slit = (
            np.arange(self.shape[0]) + 1 - self.wcs.wcs.crpix[1]
        ) * self.pixel_scale

        # Obtain the pixel coordinates perpendicular to the slit
        self.spat_slit_wid = (
            np.arange(self.shape[1]) + 1 - self.wcs.wcs.crpix[0]
        ) * self.pixel_scale


class ArchivalImage:
    """
    Base class for loading images from archival services

    Images are stored as ``{path}/{flt}.fits`` (one file per filter) and,
    once WCS-calibrated, as ``{path}/{flt}_wcs.fits``.

    Parameters
    ----------
    ra, dec : float
        Target coordinates.
    filters : str
        Filter names; the string is iterated character by character, so each
        filter is a single letter.
    path : str
        Directory holding the images; created if it does not exist.
    """

    def __init__(self, ra: float, dec: float, filters: str, path: str):
        self.ra = ra
        self.dec = dec
        self.filters = filters
        self.path = path
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(path, exist_ok=True)
        # Effective wavelength per filter; filled in by subclasses
        self.wv_eff_dict: dict[str, float] = dict()

    def check_exists(self) -> bool:
        """
        Check if images already exist in the path

        Returns
        -------
        bool
            True only if ``{path}/{flt}.fits`` exists for every filter.
        """
        return all(os.path.exists(f"{self.path}/{flt}.fits") for flt in self.filters)

    def load(self) -> tuple[list, list]:
        """
        Load images from the path

        For each filter the original ``{flt}.fits`` must exist and contain
        data that is not all-NaN / all-zero. The WCS-calibrated
        ``{flt}_wcs.fits`` is loaded in preference when present. Filters that
        cannot be loaded are dropped from ``self.filters``.

        Returns
        -------
        tuple[list, list]
            Image arrays and their headers, in the order of the (updated)
            ``self.filters``.
        """
        data_list = []
        header_list = []
        loaded_filters = []

        for flt in self.filters:
            file_orig = f"{self.path}/{flt}.fits"
            file_wcs = f"{self.path}/{flt}_wcs.fits"

            # The original file must exist and must not be empty
            try:
                data = fits.getdata(file_orig)
            except FileNotFoundError:
                msgs.warning(f"File {file_orig} not found.")
                continue
            if np.all(np.isnan(data)) or np.all(data == 0):
                msgs.warning(f"No data for filter {flt}.")
                continue

            if os.path.exists(file_wcs):
                file_to_load = file_wcs
                msgs.info(f"Loading WCS calibrated file: {file_wcs}")
            else:
                file_to_load = file_orig
                msgs.warning(f"Loading original file (WCS not calibrated by Astrometry.net): {file_orig}")

            # Load the file (best effort: a failure only skips this filter)
            try:
                if file_to_load != file_orig:
                    # The original was already read above; only re-read when
                    # the calibrated copy is used instead.
                    data = fits.getdata(file_to_load)
                header = fits.getheader(file_to_load)
            except Exception as e:
                msgs.warning(f"Failed to load {file_to_load}: {e}")
                continue

            # Store successful load
            data_list.append(data)
            header_list.append(header)
            loaded_filters.append(flt)

        # Update filters with successfully loaded ones, keeping the container
        # type: a str stays a str so that code interpolating self.filters
        # (e.g. into a query URL) keeps working after load().
        if isinstance(self.filters, str):
            self.filters = "".join(loaded_filters)
        else:
            self.filters = loaded_filters
        return data_list, header_list

    def get_cutout(self, overwrite: bool = False):
        """
        Get cutout images from the archival service

        Downloads the images unless they already exist (or ``overwrite`` is
        set), then runs the Astrometry.net WCS calibration on ``self.path``.
        """

        # Check if images already exist
        if not overwrite and self.check_exists():
            msgs.info("Images already exist.")
        else:
            self.download(overwrite=overwrite)

        # WCS calibration with Astrometry.net
        query_astrometry_net_wcs(self.path, overwrite=overwrite)

    def download(self, overwrite: bool = True):
        """
        Download images from the archival service

        Raises
        ------
        NotImplementedError
            Always; subclasses provide the implementation.
        """
        raise NotImplementedError("Subclasses should implement this method.")


class PS1Image(ArchivalImage):
    """
    Class to load images from the PS1 Image Cutout Service

    Parameters
    ----------
    ra, dec : float
        Target coordinates, passed to the PS1 services as ``ra`` / ``dec``.
    filters : str
        PS1 filters to fetch, one letter each (default ``"grizy"``).
    path : str
        Directory where ``{flt}.fits`` files are written.
    size : int
        Cutout size passed as ``size`` to ``fitscut.cgi`` (presumably in
        pixels - confirm against the PS1 cutout documentation).
    """

    def __init__(
        self,
        ra: float,
        dec: float,
        filters: str = "grizy",
        path: str = "./ps1_cutout",
        size: int = 2400,
    ):
        super().__init__(ra, dec, filters, path)
        self.size = size
        self.wv_eff_dict = dict(g=4810.16, r=6155.47, i=7503.03, z=8668.36, y=9613.60)

    def download(self, overwrite: bool = False):
        """
        Download images from the PS1 Image Cutout Service

        One ``{path}/{flt}.fits`` file is fetched with ``wget`` per filter,
        followed by the Astrometry.net WCS calibration of ``self.path``.
        Existing files are kept unless ``overwrite`` is set.
        """
        if self.dec < -30:
            msgs.warning("PS1 images are not available for declinations < -30 degrees.")
            return

        # Check if images already exist
        if not overwrite and self.check_exists():
            msgs.info("PS1 images already exist.")
            return

        url_by_filter = self._geturl_by_filter()
        if len(url_by_filter) == 0:
            msgs.warning("No images found in the PS1 database.")
            return

        os.makedirs(self.path, exist_ok=True)
        for flt in self.filters:
            fitspath = f"{self.path}/{flt}.fits"
            if os.path.exists(fitspath) and not overwrite:
                continue
            # Look the URL up by filter name: the service may return fewer
            # filters than requested, or in a different order, so pairing
            # by position would save the wrong band or raise IndexError.
            url = url_by_filter.get(flt)
            if url is None:
                msgs.warning(f"No PS1 image found for filter {flt}.")
                continue
            result = subprocess.run(["wget", url, "-O", fitspath])
            if result.returncode != 0:
                msgs.warning(
                    f"Failed to download PS1 image for filter {flt} "
                    f"(wget exit status {result.returncode})."
                )
                # wget -O creates the file even on failure; remove it so a
                # later check_exists() does not mistake it for a real image.
                if os.path.exists(fitspath):
                    os.remove(fitspath)

        # NOTE(review): ArchivalImage.get_cutout() runs this calibration again
        # after download(); kept here for callers that use download() directly.
        query_astrometry_net_wcs(self.path, overwrite=overwrite)

    def _getimages(self):
        """Query ps1filenames.py service to get a list of images"""
        service = "https://ps1images.stsci.edu/cgi-bin/ps1filenames.py"
        url = f"{service}?ra={self.ra}&dec={self.dec}&filters={self.filters}"
        table = Table.read(url, format="ascii")
        return table

    def _urlbase(self) -> str:
        """Return the fitscut.cgi URL prefix to which a file name is appended"""
        url = (
            f"https://ps1images.stsci.edu/cgi-bin/fitscut.cgi?"
            f"ra={self.ra}&dec={self.dec}&size={self.size}&format=fits"
        )
        return url + "&red="

    def _geturl_by_filter(self) -> dict:
        """Get a mapping of filter name to cutout URL (first match per filter)"""
        table = self._getimages()
        urlbase = self._urlbase()
        urls: dict = {}
        for flt, filename in zip(table["filter"], table["filename"]):
            urls.setdefault(str(flt), urlbase + str(filename))
        return urls

    def _geturl(self):
        """Get URL for images in the table"""
        table = self._getimages()

        # sort filters in "grizy" order (blue to red)
        flist = ["grizy".find(x) for x in table["filter"]]
        table = table[np.argsort(flist)]

        urlbase = self._urlbase()
        return [urlbase + filename for filename in table["filename"]]
class SDSSImage(ArchivalImage):
    """
    Class to load images from astroquery.sdss.SDSS

    Parameters
    ----------
    ra, dec : float
        Target coordinates in degrees.
    filters : str
        SDSS bands to fetch, one letter each (default ``"ugriz"``).
    path : str
        Directory where ``{flt}.fits`` files are written.
    """

    def __init__(
        self, ra: float, dec: float, filters: str = "ugriz", path: str = "./sdss_cutout"
    ):
        super().__init__(ra, dec, filters, path)
        self.wv_eff_dict = dict(u=3556.52, g=4702.50, r=6175.58, i=7489.98, z=8946.71)

    def download(self, overwrite: bool = False) -> None:
        """
        Download images from astroquery.sdss.SDSS

        The first image returned for each band is written to
        ``{path}/{flt}.fits``. Existing files are kept unless ``overwrite``
        is set.
        """
        # Check if images already exist
        if not overwrite and self.check_exists():
            msgs.info("SDSS images already exist.")
            return

        from astroquery.sdss import SDSS

        # Define target coordinates (RA, Dec)
        coord = SkyCoord(ra=self.ra, dec=self.dec, unit="deg")

        # Query the SDSS image cutout
        result = SDSS.query_region(coord, radius=30 * u.arcsec)
        if result is None or len(result) == 0:
            msgs.warning("No images found in the SDSS database.")
            return

        # Get the images
        os.makedirs(self.path, exist_ok=True)
        for flt in self.filters:
            fitspath = f"{self.path}/{flt}.fits"
            # Decide before downloading, so an existing file costs no query
            if os.path.exists(fitspath) and not overwrite:
                continue

            img = SDSS.get_images(matches=result, band=flt)
            if img is None or len(img) == 0:
                continue

            try:
                # Reaching this point with an existing file means overwrite
                # was requested; without overwrite=True writeto() would raise.
                img[0][0].writeto(fitspath, overwrite=True)
            finally:
                # Release the file handles held by the returned HDULists
                for hdul in img:
                    hdul.close()
class LSImage(ArchivalImage):
    """
    Class to load images from the DESI Legacy Imaging Surveys

    Parameters
    ----------
    ra, dec : float
        Target coordinates, passed to the cutout service as ``ra`` / ``dec``.
    filters : str
        Bands of interest, one letter each (default ``"griz"``).
    path : str
        Directory where ``{flt}.fits`` files are written.
    """

    def __init__(
        self, ra: float, dec: float, filters: str = "griz", path: str = "./ls_cutout"
    ):
        super().__init__(ra, dec, filters, path)
        # https://noirlab.edu/science/programs/ctio/filters/Dark-Energy-Camera
        self.wv_eff_dict = dict(g=4730, r=6420, i=7840, z=9260)
        self.pixel_scale = 0.25  # sent as `pixscale` to the cutout service
        self.size = 10  # arcmin

    def download(self, overwrite: bool = False):
        """
        Download images from the DESI Legacy Imaging Surveys

        A single multi-band cutout is fetched to ``{path}/dummy.fits`` and
        split into one ``{path}/{band}.fits`` file per band listed in the
        ``bands`` header keyword. Bands whose plane is all-NaN or all-zero
        are removed from ``self.filters``.
        """
        # Check if images already exist
        if not overwrite and self.check_exists():
            msgs.info("LS images already exist.")
            return

        url = "https://www.legacysurvey.org/viewer/fits-cutout?ra={ra:.4f}&dec={dec:.4f}&pixscale={px_scale}&layer=ls-dr10&size={size:g}".format(
            ra=self.ra,
            dec=self.dec,
            px_scale=self.pixel_scale,
            # arcmin -> arcsec -> pixels
            size=self.size * 60 / self.pixel_scale,
        )
        dummy = f"{self.path}/dummy.fits"
        status = self._download_url(url=url, outfile=dummy)

        if not status:
            msgs.warning("Failed to download LS images.")
            msgs.warning(
                "Field might not be covered by DESI LS or https://www.legacysurvey.org/viewer might be offline."
            )
            return

        # Postprocessing: the context manager closes the file before it is
        # removed below.
        with fits.open(dummy) as hdu:
            hdu_header = hdu[0].header
            for k, f in enumerate(str(hdu_header["bands"])):
                hdu_header["FILTER"] = f
                # NOTE(review): fixed GAIN value written to every band;
                # its origin is not documented here - confirm.
                hdu_header["GAIN"] = 100
                data = hdu[0].data[k, :, :]
                fits.writeto(
                    f"{self.path}/{f}.fits",
                    data,
                    hdu_header,
                    output_verify="silentfix",
                    overwrite=True,
                )
                # If the filter is missing, update the filters list
                if np.all(np.isnan(data)) or np.all(data == 0):
                    msgs.warning(f"No data for filter {f}.")
                    self.filters = "".join([x for x in self.filters if x != f])

        # Remove the dummy file
        os.remove(dummy)

    def _download_url(self, url: str, outfile: str, max_attempts: int = 5) -> bool:
        """
        Download a file from a URL using wget with retry logic.

        Adapted from: https://github.com/ygordon/cutout_fetching/blob/master/cutout_fetching_lite.py

        Parameters
        ----------
        url : str
            URL to fetch.
        outfile : str
            Destination file passed to ``wget -O``.
        max_attempts : int
            Maximum number of download attempts.

        Returns
        -------
        bool
            True if the download was successful, False once all attempts
            have failed.
        """
        # The service often answers with transient errors such as
        # "HTTP Error 504: Gateway Time-out", so the download is repeated
        # for up to `max_attempts` tries.
        import time

        for attempt in range(max_attempts):
            try:
                subprocess.run(["wget", url, "-O", outfile], check=True)
                return True
            except subprocess.CalledProcessError as e:
                # With check=True a failed wget surfaces as
                # CalledProcessError; urllib exceptions are never raised.
                msgs.warning(
                    f"Failed attempt {attempt} to download {outfile}: "
                    f"wget exited with status {e.returncode}"
                )
            if attempt + 1 < max_attempts:
                time.sleep(1)

        # wget -O creates the output file even when the transfer fails
        if os.path.exists(outfile):
            os.remove(outfile)
        msgs.warning(f"Failed to download image {outfile}")
        return False