Source code for lenstronomy.ImSim.SourceReconstruction.solve_regularization_strength

import numpy as np
from typing import Callable


[docs] def d_log_evi_d_lambda(l: float, U: np.ndarray, M: np.ndarray, b: np.ndarray) -> float: """Computes the derivative of the logarithm of the Bayesian evidence with respect to the regularization strength (lambda, l). This function calculates the derivative as: d(ln(Evidence))/d(lambda) ~ N_s/lambda - tr[(M+lambda*U)^-1 * U] - b^T * (M+lambda*U)^-1 * U * (M+lambda*U)^-1 * b Where: - N_s: Number of source pixels (U.shape[0]) - lambda: The regularization strength - U: The regularization matrix - M: The M matrix - b: The b vector :param l: The current value of the regularization strength (lambda). :param U: The regularization matrix (numpy.ndarray). :param M: The M matrix (numpy.ndarray). :param b: The b vector (numpy.ndarray). :return: The computed derivative value (float). """ N_source = U.shape[0] lambda_U = l * U M_plus_lambda_U = M + lambda_U # Compute the inverse of (M + lambda * U) M_plus_lambda_U_inv = np.linalg.inv(M_plus_lambda_U) # Compute the trace term: tr[(M+lambda*U)^-1 * U] trace_term_matrix = np.matmul(M_plus_lambda_U_inv, U) trace_term = np.trace(trace_term_matrix) # Compute the quadratic term: b^T * (M+lambda*U)^-1 * U * (M+lambda*U)^-1 * b M_plus_lambda_U_inv_b = np.matmul(M_plus_lambda_U_inv, b) U_times_M_plus_lambda_U_inv_b = np.matmul(U, M_plus_lambda_U_inv_b) # Using np.sum(v1 * v2) is equivalent to v1^T @ v2 for 1D vectors quadratic_term = np.sum(M_plus_lambda_U_inv_b * U_times_M_plus_lambda_U_inv_b) derivative_value = (N_source / l) - trace_term - quadratic_term return derivative_value
[docs] def solve_optimal_lambda( derivative_function: Callable[[float, np.ndarray, np.ndarray, np.ndarray], float], U: np.ndarray, M: np.ndarray, b: np.ndarray, initial_lower_bound: float, initial_upper_bound: float, tolerance: float = 1e-7, max_iterations: int = 20, check_initial_bounds: bool = True, ) -> float: """Finds the optimal regularization strength (lambda) by solving for the root of the log-evidence derivative using a bisection method. The optimal lambda is typically the value where the derivative of the log-evidence is zero. This function assumes that the derivative `d(ln(Evidence))/d(lambda)` is monotonically decreasing and crosses zero within the specified bounds. :param derivative_function: A callable function that computes the derivative d(ln(Evidence))/d(lambda). It must accept (regularization_strength, data_matrix, regularization_matrix, data_vector) as its arguments. :param U: The regularization matrix (numpy.ndarray). :param M: The M matrix (numpy.ndarray). :param b: The b vector (numpy.ndarray). :param initial_lower_bound: The lower bound for the search range of lambda. It is expected that `derivative_function(initial_lower_bound, ...)` > 0. :param initial_upper_bound: The upper bound for the search range of lambda. It is expected that `derivative_function(initial_upper_bound, ...)` < 0. :param tolerance: float, The desired absolute tolerance for the lambda value. The search stops when the width of the search interval is less than this value. Defaults to 1e-7. :param max_iterations: int, The maximum number of bisection iterations to perform. Defaults to 20. :param check_initial_bounds: bool, If True, perform checks to ensure that `initial_lower_bound` < `initial_upper_bound` and that the derivative function returns the expected signs at the boundaries (positive at lower, negative at upper). Setting this to False can speed up repeated calls if the bounds are guaranteed to be valid, but disables critical error checking. Defaults to True. :return: float, The optimized regularization strength (lambda) that maximizes the log-evidence. :raises ValueError: If `check_initial_bounds` is True and `initial_lower_bound` is not strictly less than `initial_upper_bound`, or if the derivative function does not yield the expected signs at the initial bounds (i.e., the root is not bracketed). """ if check_initial_bounds: if not (initial_lower_bound < initial_upper_bound): raise ValueError( "`initial_lower_bound` must be strictly less than `initial_upper_bound`." ) # Check initial conditions to ensure the root is bracketed # For a monotonically decreasing derivative crossing zero: # derivative at lower bound should be positive # derivative at upper bound should be negative derivative_at_lower_bound = derivative_function(initial_lower_bound, U, M, b) derivative_at_upper_bound = derivative_function(initial_upper_bound, U, M, b) if derivative_at_lower_bound <= 0: raise ValueError( f"Derivative at `initial_lower_bound` ({initial_lower_bound}) is {derivative_at_lower_bound} " f"(expected > 0). The root might not be bracketed correctly or is outside this range." ) if derivative_at_upper_bound >= 0: raise ValueError( f"Derivative at `initial_upper_bound` ({initial_upper_bound}) is {derivative_at_upper_bound} " f"(expected < 0). The root might not be bracketed correctly or is outside this range." ) current_lower_bound = initial_lower_bound current_upper_bound = initial_upper_bound for iteration_count in range(max_iterations): # Check for convergence based on interval width if np.abs(current_upper_bound - current_lower_bound) < tolerance: break mid_point_lambda = (current_upper_bound + current_lower_bound) / 2 derivative_at_mid_point = derivative_function(mid_point_lambda, U, M, b) if derivative_at_mid_point < 0: # The root is in the lower half of the current interval current_upper_bound = mid_point_lambda elif derivative_at_mid_point > 0: # The root is in the upper half of the current interval current_lower_bound = mid_point_lambda # Return the midpoint of the final interval as the approximate optimal lambda return (current_lower_bound + current_upper_bound) / 2