# Source code for lenstronomy.LensModel.Solver.epl_shear_solver

# Module authors, stored as a tuple of user names.
__author__ = ("ewoudwempe", "sibirrer")

import numpy as np
from lenstronomy.LensModel.Util.epl_util import (
    min_approx,
    pol_to_cart,
    cart_to_pol,
    cdot,
    ps,
    rotmat,
    solvequadeq,
    brentq_inline,
)
from lenstronomy.LensModel.Util.epl_util import pol_to_ell, ell_to_pol, geomlinspace
from lenstronomy.Util.image_util import findOverlap
from lenstronomy.LensModel.Profiles.epl_numba import alpha, omega
from lenstronomy.Util.numba_util import jit
from lenstronomy.Util.param_util import (
    ellipticity2phi_q,
    shear_cartesian2polar,
    shear_polar2cartesian,
)
from lenstronomy.LensModel.Profiles.shear import Shear


@jit()
def _alpha_epl_shear(x, y, b, q, t=1, gamma1=0, gamma2=0, Omega=None):
    """Complex deflection of the combined EPL+SHEAR model.

    The EPL contribution is taken from ``alpha``; the shear contributes
    ``gamma1 * x + gamma2 * y`` plus ``1j * (gamma2 * x - gamma1 * y)``.
    """
    epl_deflection = alpha(x, y, b, q, t=t, Omega=Omega)
    shear_first = gamma1 * x + gamma2 * y
    shear_second = gamma2 * x - gamma1 * y
    return epl_deflection + shear_first + 1j * shear_second


@jit()
def _one_dim_lens_eq_calcs(args, phi):
    """Compute the intermediate quantities shared by the one-dimensional lens
    equation helpers.

    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :param phi: angle(s) at which the quantities are evaluated
    :return: tuple ``(Omega, const, phiell, q, r, rhat, t, b, thetahat, y)``
    """
    b, t, y1, y2, q, gamma1, gamma2 = args
    y = y1 + y2 * 1j
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)

    # Both shear-modified direction vectors share the same prefactor.
    shear_prefactor = 1 / (1 - gamma1**2 - gamma2**2)
    rhat = shear_prefactor * (
        ((1 + gamma1) * cos_phi + gamma2 * sin_phi)
        + 1j * (gamma2 * cos_phi + (1 - gamma1) * sin_phi)
    )
    thetahat = shear_prefactor * (
        ((1 + gamma1) * sin_phi - gamma2 * cos_phi)
        + 1j * (gamma2 * sin_phi - (1 - gamma1) * cos_phi)
    )

    frac_Roverrsh, phiell = pol_to_ell(1, phi, q)
    Omega = omega(phiell, t, q)
    const = 2 * b / (1 + q)
    if abs(t - 1) > 1e-4:
        # General slope: solve for (b/R)^(t-1) and invert the power law,
        # carrying the sign along so that negative radii can be filtered later.
        power_term = -cdot(y, thetahat) / (const * cdot(Omega, thetahat))
        R = b * np.abs(power_term) ** (1 / (1 - t)) * np.sign(power_term)
    else:
        # t within 1e-4 of 1: the exponent 1 / (1 - t) above is singular there,
        # so the radius is obtained from a separate expression.
        Omega_ort = 1j * Omega
        direction = ((1 - gamma1) * cos_phi - gamma2 * sin_phi) + 1j * (
            -gamma2 * cos_phi + (1 + gamma1) * sin_phi
        )
        R = cdot(Omega_ort, y) / cdot(Omega_ort, direction) * frac_Roverrsh
    r, _theta_unused = ell_to_pol(R, phiell, q)
    return Omega, const, phiell, q, r, rhat, t, b, thetahat, y


@jit()
def _one_dim_lens_eq_both(phi, args):
    """Evaluate the smooth and the not-smooth one-dimensional lens equation in a
    single pass.

    :param phi: angle(s) at which to evaluate
    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :return: ``(eq, eq_notsmooth)``
    """
    Omega, const, phiell, q, _r, rhat, t, b, thetahat, y = _one_dim_lens_eq_calcs(
        args, phi
    )
    rr, _theta_unused = ell_to_pol(1, phiell, q)

    # Dot products that appear in both forms of the equation.
    y_theta = cdot(y, thetahat)
    omega_theta = cdot(Omega, thetahat)
    ip = cdot(y, rhat) * omega_theta - cdot(Omega, rhat) * y_theta

    # The derivations are lost somewhere in my notes...
    exponent = 2 / t
    smooth_first = (rr * b) ** (exponent - 2) * ps(y_theta / const, exponent) * ip**2
    smooth_second = ps(ip, exponent) * omega_theta**2
    eq = smooth_first + smooth_second

    rough_first = ps(rr * b, 1 - t) * (y_theta / const) * np.abs(ip) ** t
    rough_second = ip * np.abs(omega_theta) ** t
    eq_notsmooth = rough_first + rough_second
    return eq, eq_notsmooth


@jit()
def _getr(phi, args):
    """Return the radius r that belongs to the angle phi.

    :param phi: angle(s) at which to evaluate
    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :return: the ``r`` entry computed by ``_one_dim_lens_eq_calcs``
    """
    # r is the fifth entry of the tuple returned by the shared helper.
    return _one_dim_lens_eq_calcs(args, phi)[4]


@jit()
def _one_dim_lens_eq(phi, args):
    """Evaluate the smooth one-dimensional lens equation; intended as the target
    function of a root-finder.

    :param phi: angle(s) at which to evaluate
    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :return: value of the smooth equation at ``phi``
    """
    Omega, const, phiell, q, _r, rhat, t, b, thetahat, y = _one_dim_lens_eq_calcs(
        args, phi
    )
    rr, _theta_unused = ell_to_pol(1, phiell, q)

    y_theta = cdot(y, thetahat)
    omega_theta = cdot(Omega, thetahat)
    ip = cdot(y, rhat) * omega_theta - cdot(Omega, rhat) * y_theta

    exponent = 2 / t
    first_term = (rr * b) ** (exponent - 2) * ps(y_theta / const, exponent) * ip**2
    second_term = ps(ip, exponent) * omega_theta**2
    return first_term + second_term


@jit()
def _one_dim_lens_eq_unsmooth(phi, args):
    """Evaluate the not-smooth one-dimensional lens equation; intended as the
    target function of a root-finder.

    For some parameters and solutions, numerical issues make solving this one
    feasible while the smooth version is not.

    :param phi: angle(s) at which to evaluate
    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :return: value of the not-smooth equation at ``phi``
    """
    Omega, const, phiell, q, _r, rhat, t, b, thetahat, y = _one_dim_lens_eq_calcs(
        args, phi
    )
    rr, _theta_unused = ell_to_pol(1, phiell, q)

    y_theta = cdot(y, thetahat)
    omega_theta = cdot(Omega, thetahat)
    ip = cdot(y, rhat) * omega_theta - cdot(Omega, rhat) * y_theta

    first_term = ps(rr * b, 1 - t) * (y_theta / const) * np.abs(ip) ** t
    second_term = ip * np.abs(omega_theta) ** t
    return first_term + second_term


@jit()
def _getphi(thpl, args):
    """Finds all roots of both versions of the 1-dimensional lens equation in phi, by
    doing a grid search for sign changes on the supplied thpl. In the case of extrema,
    refine at the relevant location.

    The search runs in two passes: first every pair of neighbouring grid points is
    checked for a sign change, then every triple of neighbouring grid points is
    checked for a local extremum that may hide a pair of roots between grid points.
    All roots are reduced modulo 2*pi before they are stored.

    :param thpl: What points to calculate the equation on use for detecting sign changes
    :param args: Parameters to be passed to the lens equation
    :return: an array containing all roots
    """
    # Evaluate the smooth and the not-smooth equation on the whole grid at once.
    y, y_notsmooth = _one_dim_lens_eq_both(thpl, args)
    num_phi = len(thpl)
    roots = []
    # Pass 1: bracket a root wherever one of the equations changes sign between
    # two neighbouring grid points. The smooth equation takes precedence; the
    # not-smooth one is only used when the smooth one shows no sign change.
    for i in range(num_phi - 1):
        if y[i + 1] * y[i] <= 0:
            roots.append(
                brentq_inline(_one_dim_lens_eq, thpl[i], thpl[i + 1], args=args)
                % (2 * np.pi)
            )
        elif y_notsmooth[i + 1] * y_notsmooth[i] <= 0:
            roots.append(
                brentq_inline(
                    _one_dim_lens_eq_unsmooth, thpl[i], thpl[i + 1], args=args
                )
                % (2 * np.pi)
            )

    # Pass 2: look at three consecutive samples (indices wrap around via the
    # modulo) and search for extrema that could hide two roots.
    for i in range(1, num_phi + 1):
        y1 = y[i - 1]
        y2 = y[i % num_phi]
        y3 = y[(i + 1) % num_phi]
        y1n = y_notsmooth[(i - 1) % num_phi]
        y2n = y_notsmooth[i % num_phi]
        y3n = y_notsmooth[(i + 1) % num_phi]
        # The slope changes sign across the middle sample in at least one of the
        # two equations, i.e. there is a local extremum nearby.
        if (y3 - y2) * (y2 - y1) <= 0 or (y3n - y2n) * (y2n - y1n) <= 0:
            # A sign change of the smooth equation next to the middle sample was
            # already handled in pass 1.
            if y3 * y2 <= 0 or y1 * y2 <= 0:
                continue
            # Triples that would wrap around the end of the grid are skipped, so
            # thpl is only ever indexed without wrap-around below.
            if i > num_phi - 2:
                continue
            else:
                x1 = thpl[i - 1]
                x2 = thpl[i]
                x3 = thpl[(i + 1)]

            # Estimated extremum location from the three samples, separately for
            # each equation (presumably a parabolic fit - confirm in epl_util).
            xmin = min_approx(x1, x2, x3, y1, y2, y3)
            xmin_ns = min_approx(x1, x2, x3, y1n, y2n, y3n)
            ymin = _one_dim_lens_eq(xmin, args)
            ymin_ns = _one_dim_lens_eq_unsmooth(xmin_ns, args)
            # If the value at the estimated extremum has the opposite sign of the
            # middle sample, the extremum crosses zero: bracket one root on each
            # side of it. The smooth equation is tried first, then the not-smooth.
            if ymin * y2 <= 0 and x2 <= xmin <= x3:
                roots.append(
                    brentq_inline(_one_dim_lens_eq, x2, xmin, args=args) % (2 * np.pi)
                )
                roots.append(
                    brentq_inline(_one_dim_lens_eq, xmin, x3, args=args) % (2 * np.pi)
                )
            elif ymin * y2 <= 0 and x1 <= xmin <= x2:
                roots.append(
                    brentq_inline(_one_dim_lens_eq, x1, xmin, args=args) % (2 * np.pi)
                )
                roots.append(
                    brentq_inline(_one_dim_lens_eq, xmin, x2, args=args) % (2 * np.pi)
                )
            elif ymin_ns * y2n <= 0 and x2 <= xmin_ns <= x3:
                roots.append(
                    brentq_inline(_one_dim_lens_eq_unsmooth, x2, xmin_ns, args=args)
                    % (2 * np.pi)
                )
                roots.append(
                    brentq_inline(_one_dim_lens_eq_unsmooth, xmin_ns, x3, args=args)
                    % (2 * np.pi)
                )
            elif ymin_ns * y2n <= 0 and x1 <= xmin_ns <= x2:
                roots.append(
                    brentq_inline(_one_dim_lens_eq_unsmooth, x1, xmin_ns, args=args)
                    % (2 * np.pi)
                )
                roots.append(
                    brentq_inline(_one_dim_lens_eq_unsmooth, xmin_ns, x2, args=args)
                    % (2 * np.pi)
                )

    return np.array(roots)


def solvelenseq_majoraxis(args, Nmeas=200, Nmeas_extra=50):
    """Solve the lens equation, where the arguments have been properly rotated to the
    major-axis.

    :param args: tuple ``(b, t, y1, y2, q, gamma1, gamma2)``
    :param Nmeas: number of linearly spaced angles sampled on [0, pi]
    :param Nmeas_extra: number of geometrically spaced offsets (from 1e-4 to 0.1)
        sampled on either side of the angle ``p1`` computed from the source position
        and the shear
    :return: the output of ``findOverlap`` applied to the x and y coordinates of the
        accepted solutions, with a tolerance of 1e-8
    """
    b, t, y1, y2, q, gamma1, gamma2 = args
    lens_args = (b, t, y1, y2, q, gamma1, gamma2)

    # Angle around which the angular grid is refined.
    p1 = np.arctan2(
        y2 * (1 - gamma1) + gamma2 * y1, y1 * (1 + gamma1) + gamma2 * y2
    )
    p1_folded = p1 % np.pi
    geom = geomlinspace(1e-4, 0.1, Nmeas_extra)
    thpl = np.sort(
        np.concatenate(
            (
                np.linspace(0.0, np.pi, Nmeas),
                p1_folded - geom,
                p1_folded + geom,
            )
        )
    )

    the = _getphi(thpl, lens_args)
    # Every root angle is also tried at angle + pi; only candidates with a positive
    # radius are kept.
    thetas = np.concatenate((the, the + np.pi))
    Rs = np.array([_getr(theta, lens_args) for theta in thetas])
    positive = Rs > 0
    stuff = np.array(pol_to_cart(Rs[positive], thetas[positive]))

    # Residual of the full two-dimensional lens equation at each candidate.
    diff = (
        -y1
        - y2 * 1j
        + stuff[0]
        + stuff[1] * 1j
        - _alpha_epl_shear(stuff[0], stuff[1], b, q, t, gamma1=gamma1, gamma2=gamma2)
    )
    goodones = np.abs(diff) < 1e-8
    return findOverlap(*stuff[:, goodones], 1e-8)
def _check_center(kwargs_lens):
    """Checks if the shear-at-center convention is properly used.

    :param kwargs_lens: list of lens keyword dictionaries; the first entry must
        provide ``center_x`` and ``center_y`` when a second (shear) entry is present
    :return: ``(alpha_x, alpha_y)``, the shear deflection evaluated at the centroid
        of the first lens, or ``(0, 0)`` when there is no second entry
    """
    # calculate (inverse) displacement caused by the offset between shear and lens centroid
    # this shift needs to be added to the source position such that the solution of the lens equation
    # without this shift in the shear is the correct one
    if len(kwargs_lens) <= 1:
        return 0, 0
    shear = Shear()
    # calculate shift from the deflector centroid from the shear field
    alpha_x, alpha_y = shear.derivatives(
        kwargs_lens[0]["center_x"], kwargs_lens[0]["center_y"], **kwargs_lens[1]
    )
    return alpha_x, alpha_y
def solve_lenseq_pemd(pos_, kwargs_lens, Nmeas=400, Nmeas_extra=80, **kwargs):
    """Solves the lens equation using a semi-analytical recipe.

    :param pos_: The source plane position (shape (2,)), or the source plane positions
        (shape (2,N)) for which to solve the lens equation
    :param kwargs_lens: List of kwargs in lenstronomy style, following ['EPL', 'SHEAR']
        format
    :param Nmeas: resolution with which to sample the angular grid, higher means more
        reliable lens equation solving. For solving many positions at once, you may
        want to set this higher.
    :param Nmeas_extra: resolution with which to additionally sample the angular grid
        at the low-shear end, higher means more reliable lens equation solving. For
        solving many positions at once, you may want to set this higher.
    :param kwargs: additional keyword arguments; accepted but not used by this function
    :return: The lens plane positions. Note: generally the (demagnified) central image
        will also be included.
    """
    pos = np.asarray(pos_)
    if pos.ndim > 1 and pos.shape[-1] != 1:
        pos = pos[..., None]

    # Missing slope or ellipticity entries fall back to t = 1 and e1 = e2 = 0.
    t = kwargs_lens[0]["gamma"] - 1 if "gamma" in kwargs_lens[0] else 1
    if "e1" in kwargs_lens[0] and "e2" in kwargs_lens[0]:
        e1 = kwargs_lens[0]["e1"]
        e2 = kwargs_lens[0]["e2"]
    else:
        e1, e2 = 0, 0
    theta_ell, q = ellipticity2phi_q(e1, e2)
    b = kwargs_lens[0]["theta_E"] * np.sqrt(q)

    if len(kwargs_lens) > 1:
        gamma = kwargs_lens[1]["gamma1"] + 1j * kwargs_lens[1]["gamma2"]
    else:
        gamma = 0 + 0j

    # Source position relative to the lens centre, corrected for the shear evaluated
    # at the lens centre.
    shift_x, shift_y = _check_center(kwargs_lens)
    shift = shift_x + 1j * shift_y
    cen = kwargs_lens[0]["center_x"] + 1j * kwargs_lens[0]["center_y"]
    p = pos[0] + 1j * pos[1] - cen + shift

    # Rotate into the frame of the ellipse axes; the shear rotates with twice the angle.
    rotfact = np.exp(-1j * theta_ell)
    gamma *= rotfact**2
    p *= rotfact

    res = solvelenseq_majoraxis(
        (b, t, p.real, p.imag, q, gamma.real, gamma.imag),
        Nmeas=Nmeas,
        Nmeas_extra=Nmeas_extra,
    )
    xsol, ysol = res
    # Rotate the solutions back and restore the lens centre.
    x = np.array([(xs + 1j * ys) / rotfact + cen for xs, ys in zip(xsol, ysol)])
    return x.real, x.imag
def caustics_epl_shear(
    kwargs_lens, num_th=500, maginf=0, sourceplane=True, return_which=None
):
    """Analytically calculates the caustics of an EPL+shear lens model. Since for
    gamma>2, the outer critical curve does not exist, the option to find the curves for
    a set, finite magnification exists, by supplying maginf, so that the routine finds
    the curve of this magnification, rather than the true caustic.

    :param kwargs_lens: List of kwargs in lenstronomy style, following ['EPL', 'SHEAR']
        format
    :param num_th: resolution.
    :param maginf: the outer critical curve for t>1 will be replaced with the curve
        where the inverse magnification is maginf
    :param sourceplane: if True (default), ray-shoot the calculated critical curves to
        the source plane
    :param return_which: options 'quad' (boundary of area within which there are 4
        images), 'double' (boundary of area within which there are 2 images),
        'caustic' (the diamond caustic), 'cut' (the cut for t<=1; for t>1 the curve
        obtained with maginf is returned in its place) and None (in that case: return
        caustic, cut, quad)
    :return: (2,N) array if return_which set, else a tuple of (caustic, cut, quad)
    """
    e1, e2 = kwargs_lens[0]["e1"], kwargs_lens[0]["e2"]
    if len(kwargs_lens) > 1:
        gamma1unr, gamma2unr = kwargs_lens[1]["gamma1"], kwargs_lens[1]["gamma2"]
    else:
        gamma1unr, gamma2unr = 0, 0
    t = kwargs_lens[0]["gamma"] - 1 if "gamma" in kwargs_lens[0] else 1
    theta_ell, q = ellipticity2phi_q(e1, e2)
    theta_gamma, gamma_mag = shear_cartesian2polar(gamma1unr, gamma2unr)
    b = np.sqrt(q) * kwargs_lens[0]["theta_E"]
    cen = np.expand_dims(
        np.array([kwargs_lens[0]["center_x"], kwargs_lens[0]["center_y"]]), 1
    )

    # Express the shear relative to the ellipse orientation; M is applied to the
    # results before they are returned.
    theta_gamma -= theta_ell
    gamma1, gamma2 = shear_polar2cartesian(theta_gamma, gamma_mag)
    M = rotmat(-theta_ell)

    theta = np.linspace(0, 2 * np.pi, num_th, endpoint=False)
    R, phi = pol_to_ell(1, theta, q)
    Omega = omega(phi, t, q)
    frac_roverR = 1 / R

    # Coefficients passed to solvequadeq, evaluated per angle.
    aa = 1
    bb = -(2 - t)
    cc = (
        (1 - t)
        * (2 - t)
        * (cdot(np.exp(1j * theta), Omega))
        / frac_roverR
        * 2
        / (1 + q)
    )
    cc -= (1 - t) ** 2 * (2 / (1 + q)) ** 2 * np.abs(Omega) ** 2 / frac_roverR**2

    # Shear stuff:
    gammaint_fac = (
        -np.exp(2j * theta) * (2 - t) / 2
        + (1 - t) * np.exp(1j * theta) * 2 / (1 + q) * Omega / frac_roverR
    )
    gamma = gamma1 + 1j * gamma2
    aa -= np.abs(gamma) ** 2
    bb -= 2 * cdot(gamma, gammaint_fac)

    usol = np.array(solvequadeq(cc, bb, aa)).T
    xcr_4, ycr_4 = pol_to_cart(b * usol[:, 1] ** (-1 / t) * frac_roverR, theta)
    if t > 1:
        # If t>1, get the approximate outer caustic instead (where inverse
        # magnification = maginf).
        usol = np.array(solvequadeq(cc, bb, aa - maginf)).T
        xcr_cut, ycr_cut = pol_to_cart(
            b * usol[:, 1] ** (-1 / t) * frac_roverR, theta
        )
    else:
        usol = np.array(solvequadeq(cc, bb, aa + maginf)).T
        xcr_cut, ycr_cut = pol_to_cart(
            b * usol[:, 0] ** (-1 / t) * frac_roverR, theta
        )

    al_cut = _alpha_epl_shear(xcr_cut, ycr_cut, b, q, t, gamma1, gamma2, Omega=Omega)
    al_4 = _alpha_epl_shear(xcr_4, ycr_4, b, q, t, gamma1, gamma2, Omega=Omega)
    if sourceplane:
        xca_cut, yca_cut = xcr_cut - al_cut.real, ycr_cut - al_cut.imag
        xca_4, yca_4 = xcr_4 - al_4.real, ycr_4 - al_4.imag
    else:
        xca_cut, yca_cut = xcr_cut, ycr_cut
        xca_4, yca_4 = xcr_4, ycr_4

    if return_which == "caustic":
        return M @ (xca_4, yca_4) + cen
    if return_which == "cut":
        return M @ (xca_cut, yca_cut) + cen

    # Combine both curves in polar coordinates: resample the cut at the angles of
    # the caustic, then take the larger (double) or smaller (quad) radius per angle.
    rcut, thcut = cart_to_pol(xca_cut, yca_cut)
    r, th = cart_to_pol(xca_4, yca_4)
    r2 = np.interp(th, thcut, rcut, period=2 * np.pi)
    if return_which == "double":
        r = np.fmax(r, r2)
    else:  # Quad
        r = np.fmin(r, r2)
    pos_tosample = np.empty((2, num_th))
    pos_tosample[0], pos_tosample[1] = pol_to_cart(r, th)
    if return_which in ("double", "quad"):
        return M @ pos_tosample + cen

    return (
        M @ (xca_4, yca_4) + cen,
        M @ (xca_cut, yca_cut) + cen,
        M @ pos_tosample + cen,
    )


# Mostly for some backward compatibility