"""Starlet decomposition utilities (lenstronomy.LightModel.Profiles.starlets_util)."""

__author__ = "herjy", "aymgal", "sibirrer"

import numpy as np
from scipy import ndimage

from lenstronomy.Util.package_util import exporter

# `exporter()` (from lenstronomy.Util.package_util) supplies the `@export`
# decorator used on the functions below together with this module's `__all__`
# list; presumably each decorated function is appended to `__all__` — confirm
# in package_util.
export, __all__ = exporter()


@export
def transform(img, n_scales, second_gen=False):
    """Performs starlet decomposition of a 2D array.

    Uses the "a trous" scheme with the separable B-spline filter
    h = [1, 4, 6, 4, 1] / 16, dilated by a factor 2**i at step i, and
    ``mode="nearest"`` boundary handling.

    :param img: input image, a non-empty 2D array of shape (n1, n2); it does not
        need to be square
    :param n_scales: number of decomposition scales (detail scales plus the
        final coarse scale); must be >= 2
    :param second_gen: if True, 'second generation' starlets are used
    :return: array of shape (n_scales, n1, n2). Entries 0 .. n_scales - 2 hold
        the detail (wavelet) coefficients, entry n_scales - 1 holds the coarse
        approximation.
    :raises ValueError: if img is not a non-empty 2D array, if n_scales < 2, or
        if n_scales - 1 exceeds int(log2(min(n1, n2)))
    """
    mode = "nearest"
    lvl = n_scales - 1

    c = np.asarray(img)
    if c.ndim != 2 or c.size == 0:
        raise ValueError(
            "img must be a non-empty 2D array (got shape {})".format(c.shape)
        )
    # convolve1d returns an array of the input dtype, so integer images would be
    # truncated at every smoothing step; work in floating point instead.
    if not np.issubdtype(c.dtype, np.inexact):
        c = c.astype(float)
    n1, n2 = c.shape

    # B-spline filter
    h = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
    n = np.size(h)

    if lvl <= 0:
        raise ValueError("Number of decomposition level can not be non-positive")
    # The number of dilation steps is bounded by the smaller image dimension.
    max_lvl = int(np.log2(min(n1, n2)))
    if lvl > max_lvl:
        raise ValueError(
            "Maximum decomposition level is {} (required: {})".format(max_lvl, lvl)
        )

    # wavelet set of coefficients.
    wave = np.zeros((lvl + 1, n1, n2))
    for i in range(lvl):
        # Dilated filter: the 5 taps of h spaced 2**i apart, zeros in between.
        step = 2**i
        kernel = np.zeros((n - 1) * step + 1)
        kernel[::step] = h

        # c_{j+1} = h * c_j, applied separably (axis 0, then axis 1)
        cnew = ndimage.convolve1d(c, kernel, axis=0, mode=mode)
        cnew = ndimage.convolve1d(cnew, kernel, axis=1, mode=mode)

        if second_gen:
            # second generation: w_{j+1} = c_j - h * c_{j+1}
            hc = ndimage.convolve1d(cnew, kernel, axis=0, mode=mode)
            hc = ndimage.convolve1d(hc, kernel, axis=1, mode=mode)
            wave[i] = c - hc
        else:
            # first generation: w_{j+1} = c_j - c_{j+1}
            wave[i] = c - cnew
        c = cnew

    # last slot holds the coarsest smoothed image
    wave[lvl] = c
    return wave
@export
def inverse_transform(wave, fast=True, second_gen=False):
    """Reconstructs an image from its starlet decomposition coefficients.

    :param wave: input coefficients with shape (n_scales, n1, n2). Detail scales
        come first and the coarse approximation is last, i.e. the layout
        produced by `transform`.
    :param fast: if True, and only with second_gen is False, simply sums up all
        scales to reconstruct the image
    :param second_gen: if True, 'second generation' starlets are used
    :return: reconstructed 2D image of shape (n1, n2)
    :raises ValueError: if wave is not 3-dimensional (slow path only)

    Note: when fast is False, the recursion c_j = h * c_{j+1} + w_{j+1} is
    applied. It exactly inverts the second-generation transform. For
    first-generation coefficients it is not an exact inverse; the plain sum
    (fast=True) is the exact reconstruction in that case.
    """
    if fast and not second_gen:
        # simply sum all scales, including the coarsest one
        return np.sum(wave, axis=0)

    mode = "nearest"

    wave = np.asarray(wave)
    if wave.ndim != 3:
        raise ValueError(
            "wave must have shape (n_scales, n1, n2) (got {})".format(wave.shape)
        )
    lvl = wave.shape[0]

    # B-spline filter
    h = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
    n = np.size(h)

    # Start from the coarse scale; work in floating point because convolve1d
    # returns the input dtype and would truncate integer coefficients.
    cJ = np.copy(wave[lvl - 1])
    if not np.issubdtype(cJ.dtype, np.inexact):
        cJ = cJ.astype(float)

    for i in range(1, lvl):
        # Dilation mirrors the forward step that produced wave[lvl - 1 - i].
        step = 2 ** (lvl - 1 - i)
        kernel = np.zeros((n - 1) * step + 1)
        kernel[::step] = h

        # separable smoothing: axis 0, then axis 1
        cnew = ndimage.convolve1d(cJ, kernel, axis=0, mode=mode)
        cnew = ndimage.convolve1d(cnew, kernel, axis=1, mode=mode)

        cJ = cnew + wave[lvl - 1 - i]

    return cJ