Source code for lenstronomy.LightModel.Profiles.starlets

__author__ = "aymgal"

import numpy as np
import warnings

from lenstronomy.LightModel.Profiles import starlets_util
from lenstronomy.LightModel.Profiles.interpolation import Interpol
from lenstronomy.Util import util

__all__ = ["SLIT_Starlets"]


[docs]class SLIT_Starlets(object): """Decomposition of an image using the Isotropic Undecimated Walevet Transform, also known as "starlet" or "B-spline", using the 'a trous' algorithm. Astronomical data (galaxies, stars, ...) are often very sparsely represented in the starlet basis. Based on Starck et al. : https://ui.adsabs.harvard.edu/abs/2007ITIP...16..297S/abstract """ param_names = ["amp", "n_scales", "n_pixels", "scale", "center_x", "center_y"] param_names_latex = { r"$I_0$", r"$n_{\rm scales}$", r"$n_{\rm pix}$", r"scale", r"$x_0$", r"$y_0$", } lower_limit_default = { "amp": [0], "n_scales": 2, "n_pixels": 5, "center_x": -1000, "center_y": -1000, "scale": 0.000000001, } upper_limit_default = { "amp": [1e8], "n_scales": 20, "n_pixels": 1e10, "center_x": 1000, "center_y": 1000, "scale": 10000000000, }
[docs] def __init__( self, thread_count=1, fast_inverse=True, second_gen=False, show_pysap_plots=False, force_no_pysap=False, ): """Load pySAP package if found, and initialize the Starlet transform. :param thread_count: number of threads used for pySAP computations :param fast_inverse: if True, reconstruction is simply the sum of each scale (only for 1st generation starlet transform) :param second_gen: if True, uses the second generation of starlet transform :param show_pysap_plots: if True, displays pySAP plots when calling the decomposition method :param force_no_pysap: if True, does not load pySAP and computes starlet transforms in python. """ self.use_pysap, pysap = self._load_pysap(force_no_pysap) if self.use_pysap: self._transf_class = pysap.load_transform( "BsplineWaveletTransformATrousAlgorithm" ) else: warnings.warn( "The python package pySAP is not used for starlet operations. " "They will be performed using (slower) python routines." ) self._fast_inverse = fast_inverse self._second_gen = second_gen self._show_pysap_plots = show_pysap_plots self.interpol = Interpol() self.thread_count = thread_count
[docs] def function( self, x, y, amp=None, n_scales=None, n_pixels=None, scale=1, center_x=0, center_y=0, ): """1D inverse starlet transform from starlet coefficients stored in coeffs Follows lenstronomy conventions for light profiles. :param amp: decomposition coefficients ('amp' to follow conventions in other light profile) This is an ndarray with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) or (n_scales*n_pixels,) :param n_scales: number of decomposition scales :param n_pixels: number of pixels in a single scale :return: reconstructed signal as 1D array of shape (n_pixels,) """ if len(amp.shape) == 1: coeffs = util.array2cube(amp, n_scales, n_pixels) elif len(amp.shape) == 3: coeffs = amp else: raise ValueError( "Starlets 'amp' has not the right shape (1D or 3D arrays are supported)" ) image = self.function_2d(coeffs, n_scales, n_pixels) image = self.interpol.function( x, y, image=image, scale=scale, center_x=center_x, center_y=center_y, amp=1, phi_G=0, ) return image
[docs] def function_2d(self, coeffs, n_scales, n_pixels): """2D inverse starlet transform from starlet coefficients stored in coeffs. :param coeffs: decomposition coefficients, ndarray with shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 2D array of shape (sqrt(n_pixels), sqrt(n_pixels)) """ if self.use_pysap and not self._second_gen: return self._inverse_transform(coeffs, n_scales, n_pixels) else: return starlets_util.inverse_transform( coeffs, fast=self._fast_inverse, second_gen=self._second_gen )
[docs] def decomposition(self, image, n_scales): """1D starlet transform from starlet coefficients stored in coeffs. :param image: 2D image to be decomposed, ndarray with shape (sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 1D array of shape (n_scales*n_pixels,) """ if len(image.shape) == 1: image_2d = util.array2image(image) elif len(image.shape) == 2: image_2d = image else: raise ValueError( "image has not the right shape (1D or 2D arrays are supported for starlets decomposition)" ) return util.cube2array(self.decomposition_2d(image_2d, n_scales))
[docs] def decomposition_2d(self, image, n_scales): """2D starlet transform from starlet coefficients stored in coeffs. :param image: 2D image to be decomposed, ndarray with shape (sqrt(n_pixels), sqrt(n_pixels)) :param n_scales: number of decomposition scales :return: reconstructed signal as 2D array of shape (n_scales, sqrt(n_pixels), sqrt(n_pixels)) """ if self.use_pysap and not self._second_gen: coeffs = self._transform(image, n_scales) else: coeffs = starlets_util.transform( image, n_scales, second_gen=self._second_gen ) return coeffs
def _inverse_transform(self, coeffs, n_scales, n_pixels): """Reconstructs image from starlet coefficients.""" self._check_transform_pysap(n_scales, n_pixels) if self._fast_inverse and not self._second_gen: # for 1st gen starlet the reconstruction can be performed by summing all scales image = np.sum(coeffs, axis=0) else: coeffs = self._coeffs2pysap(coeffs) self._transf.analysis_data = coeffs result = self._transf.synthesis() if self._show_pysap_plots: result.show() image = result.data return image def _transform(self, image, n_scales): """Decomposes an image into starlets coefficients.""" self._check_transform_pysap(n_scales, image.size) self._transf.data = image self._transf.analysis() if self._show_pysap_plots: self._transf.show() coeffs = self._transf.analysis_data coeffs = self._pysap2coeffs(coeffs) return coeffs def _check_transform_pysap(self, n_scales, n_pixels): """If needed, update the loaded pySAP transform to correct number of scales.""" if ( not hasattr(self, "_transf") or n_scales != self._n_scales or n_pixels != self._n_pixels ): self._transf = self._transf_class( nb_scale=n_scales, verbose=False, nb_procs=self.thread_count ) self._n_scales = n_scales self._n_pixels = n_pixels def _pysap2coeffs(self, coeffs): """Convert pySAP decomposition coefficients to numpy array.""" return np.asarray(coeffs) def _coeffs2pysap(self, coeffs): """Convert coefficients stored in numpy array to list required by pySAP.""" coeffs_list = [] for i in range(coeffs.shape[0]): coeffs_list.append(coeffs[i, :, :]) return coeffs_list def _load_pysap(self, force_no_pysap): """Load pySAP module.""" if force_no_pysap: return False, None try: import pysap except ImportError: return False, None else: return True, pysap
[docs] def delete_cache(self): """Delete the cached interpolated image.""" self.interpol.delete_cache()