Source code for lenstronomy.Plots.chain_plot

import copy

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable

import sys
from typing import TYPE_CHECKING

if sys.version_info >= (3, 12):  # pragma: no cover
    from typing import Unpack
else:  # pragma: no cover
    try:  # pragma: no cover
        from typing_extensions import Unpack
    except ImportError:  # pragma: no cover
        pass

if TYPE_CHECKING:  # pragma: no cover
    from lenstronomy.Plots import plot_util

from lenstronomy.Util.package_util import exporter

# `export` is applied below as a decorator on each public plotting function,
# and `__all__` is bound from the same `exporter()` call.
# NOTE(review): presumably the decorator registers each decorated name in
# `__all__` -- confirm in lenstronomy.Util.package_util.
export, __all__ = exporter()


@export
def plot_chain_list(chain_list, index=0, num_average=100):
    """Plot convergence diagnostics for one entry of a list of sampler chains.

    The entry's first element is a type string that selects the plot:
    ``"PSO"`` entries are handed to :func:`plot_chain`; ``"MCMC"``,
    ``"emcee"``, ``"zeus"``, ``"dynesty"``, ``"dyPolyChord"`` and
    ``"MultiNest"`` entries are drawn on a single new axes by
    :func:`plot_mcmc_behaviour`. This routine is an example; further tests
    may be appropriate when analysing a specific chain.

    :param chain_list: Chains with arguments [type string, samples etc...]
    :type chain_list: list
    :param index: index of chain to be plotted
    :type index: int
    :param num_average: in chains, number of steps to average over in plotting
        diagnostics
    :type num_average: int
    :return: plotting instance figure, axes (potentially multiple)
    :raises ValueError: if the type string of the selected chain is not one of
        the supported names
    """
    selected = chain_list[index]
    sampler = selected[0]

    if sampler == "PSO":
        chain, param = selected[1:]
        f, axes = plot_chain(chain, param)
        return f, axes

    if sampler in ("MCMC", "emcee", "zeus"):
        # These entries must hold exactly three items after the type string.
        samples, param, dist = selected[1:]
    elif sampler in ("dynesty", "dyPolyChord", "MultiNest"):
        # Only the first three items after the type string are used here;
        # anything stored after them is ignored.
        samples, param, dist = selected[1:4]
    else:
        raise ValueError("chain_type %s not supported for plotting" % sampler)

    f, ax = plt.subplots(1, 1, figsize=(6, 6))
    axes = plot_mcmc_behaviour(ax, samples, param, dist, num_average=num_average)
    return f, axes
@export
def plot_chain(chain, param_list):
    """Plot PSO chain diagnostics in a three-panel figure.

    Left panel: ``log10`` of the negated first chain element, titled "-logL".
    Middle panel: for each parameter, the position history relative to its
    value at the last iteration, divided by ``(last value + 1)``.
    Right panel: for each parameter, the velocity history divided by the same
    ``(last position + 1)`` factor.

    :param chain: chi2, position, and velocity history
    :type chain: tuple
    :param param_list: Parameter names
    :type param_list: list
    :return: plotting instance figure and axes
    """
    chi2_list, pos_list, vel_list = chain

    # All drawing goes through the axes of this one figure. No additional
    # pyplot figure is opened, so nothing empty is left registered with pyplot.
    f, axes = plt.subplots(1, 3, figsize=(18, 6))

    ax = axes[0]
    ax.plot(np.log10(-np.array(chi2_list)))
    ax.set_title("-logL")

    pos = np.array(pos_list)
    vel = np.array(vel_list)
    last = len(pos) - 1

    ax = axes[1]
    for i in range(len(pos[0])):
        # NOTE(review): the "+ 1" in the denominator presumably guards against
        # dividing by a final position of exactly zero -- confirm intent.
        ax.plot(
            (pos[:, i] - pos[last, i]) / (pos[last, i] + 1),
            label=param_list[i],
        )
    ax.set_title("particle position")
    ax.legend()

    ax = axes[2]
    for i in range(len(vel[0])):
        ax.plot(vel[:, i] / (pos[last, i] + 1), label=param_list[i])
    ax.set_title("param velocity")
    ax.legend()

    return f, axes
@export
def plot_mcmc_behaviour(ax, samples_mcmc, param_mcmc, dist_mcmc=None, num_average=100):
    """Plot block-averaged MCMC traces to inspect convergence of the chain.

    The chain is cut into consecutive blocks of ``num_average`` samples;
    trailing samples that do not fill a whole block are dropped. For every
    parameter the block means are shifted by their overall mean, divided by
    their standard deviation, and plotted. If ``dist_mcmc`` is given, the
    negated per-block maximum is shifted by its largest value, divided by its
    range, and drawn as a thick black line labelled "logL".

    :param ax: Matplotlib axes instance
    :type ax: matplotlib.axes.Axes
    :param samples_mcmc: sampled parameters, indexed as [sample, parameter]
    :type samples_mcmc: numpy.ndarray
    :param param_mcmc: Parameters
    :type param_mcmc: list
    :param dist_mcmc: log likelihood of the chain
    :type dist_mcmc: numpy.ndarray or None
    :param num_average: number of samples to average (should coincide with the
        number of samples in the emcee process)
    :type num_average: int
    :return: the axes instance that was passed in, with the traces drawn on it
    """
    num_samples = len(samples_mcmc[:, 0])
    num_average = int(num_average)
    # Number of complete blocks, and how many leading samples they cover.
    n_points = num_samples // num_average
    n_used = int(n_points * num_average)

    for column, label in enumerate(param_mcmc):
        blocks = samples_mcmc[:, column][:n_used].reshape(n_points, num_average)
        block_means = np.mean(blocks, axis=1)
        centre = np.mean(block_means)
        spread = np.std(block_means)
        ax.plot((block_means - centre) / spread, label=label)

    if dist_mcmc is not None:
        logl_blocks = dist_mcmc[:n_used].reshape(n_points, num_average)
        neg_block_max = -np.max(logl_blocks, axis=1)
        upper = np.max(neg_block_max)
        lower = np.min(neg_block_max)
        ax.plot(
            (neg_block_max - upper) / (upper - lower),
            label="logL",
            color="k",
            linewidth=2,
        )

    ax.legend()
    return ax
@export
def psf_iteration_compare(
    kwargs_psf, **kwargs_matshow: "Unpack[plot_util.MatshowKwargs]"
):
    """Compare the initial and the iteratively reconstructed PSF kernels.

    Draws a row of panels: ``log10`` of the initial kernel, ``log10`` of the
    reconstructed kernel (sharing the colour limits of the first panel unless
    ``vmin``/``vmax`` are supplied), and their difference on a fixed
    ``[-10**-3, 10**-3]`` colour range. When the PSF instance built from
    ``kwargs_psf`` has a variance map, a fourth panel shows
    ``log10(psf_variance_map * kernel_point_source**2)``.

    :param kwargs_psf: keyword arguments that initiate a PSF() class; must
        contain "kernel_point_source" and "kernel_point_source_init"
    :type kwargs_psf: dict
    :param kwargs_matshow: kwargs to send to matplotlib.pyplot.matshow();
        "cmap" defaults to "seismic"
    :return: plotting instance figure and array of axes
    """
    psf_out = kwargs_psf["kernel_point_source"]
    psf_in = kwargs_psf["kernel_point_source_init"]

    from lenstronomy.Data.psf import PSF

    psf = PSF(**kwargs_psf)
    psf_variance_map = psf.psf_variance_map

    def _finish_panel(panel, image, caption, side):
        """Attach a colorbar, hide the axis ticks and write the caption."""
        cax = make_axes_locatable(panel).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(image, cax=cax)
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
        # Caption offsets are fractions of the kernel side length:
        # x at 1/20 of the side, y at the side minus 1/10 of it.
        panel.text(
            side / 20.0,
            side - side / 10.0,
            caption,
            color="k",
            fontsize=20,
            backgroundcolor="w",
        )

    n_kernel = len(psf_in)
    kwargs_matshow.setdefault("cmap", "seismic")

    n_panels = 3 if psf_variance_map is None else 4
    f, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))

    # Panel 0: initial kernel. Its colour limits become the defaults for the
    # later panels that reuse kwargs_matshow.
    im = axes[0].matshow(np.log10(psf_in), origin="lower", **kwargs_matshow)
    vmin, vmax = im.get_clim()
    kwargs_matshow.setdefault("vmin", vmin)
    kwargs_matshow.setdefault("vmax", vmax)
    _finish_panel(axes[0], im, "Initial PSF model", n_kernel)

    # Panel 1: reconstructed kernel.
    im = axes[1].matshow(np.log10(psf_out), origin="lower", **kwargs_matshow)
    _finish_panel(axes[1], im, "iterative reconstruction", n_kernel)

    # Panel 2: difference, drawn with its own fixed symmetric colour range, so
    # the shared vmin/vmax are removed from a copy of the keyword arguments.
    diff_limit = 10**-3
    kwargs_diff = {
        key: value
        for key, value in copy.deepcopy(kwargs_matshow).items()
        if key not in ("vmin", "vmax")
    }
    im = axes[2].matshow(
        psf_out - psf_in,
        origin="lower",
        vmin=-diff_limit,
        vmax=diff_limit,
        **kwargs_diff
    )
    _finish_panel(axes[2], im, "difference", n_kernel)

    # Panel 3 (optional): variance map scaled by the squared kernel; the
    # caption is positioned using the variance map's own side length.
    if psf_variance_map is not None:
        im = axes[3].matshow(
            np.log10(psf_variance_map * psf.kernel_point_source**2),
            origin="lower",
            **kwargs_matshow
        )
        _finish_panel(axes[3], im, "psf variance map", len(psf_variance_map))

    f.tight_layout()
    return f, axes