"""Restricted quartic bending-angle potential and its fitting utilities.

Module ``mlcg_tk.input_generator.prior_fit.restricted_bending``.
"""

import torch
import numpy as np
from typing import Dict, Tuple, List
from scipy.optimize import curve_fit, fsolve
from scipy.integrate import trapezoid


def restricted_quartic_angle(x, a, b, c, d, k, v_0):
    r"""Restricted quartic angle potential (PyTorch implementation).

    A quartic polynomial in :math:`\cos\theta` plus a repulsive
    :math:`1/\sin^2\theta` wall, which diverges at :math:`\theta = 0` and
    :math:`\theta = \pi` and so keeps the angle away from those singular
    configurations:

    .. math::

        V(\theta) = a\cos^4\theta + b\cos^3\theta + c\cos^2\theta
                    + d\cos\theta + \frac{k}{\sin^2\theta} + v_0

    Args:
        x: angle tensor (radians). It must be a ``torch.Tensor`` because
            ``torch.cos`` / ``torch.sin`` are applied to it.
        a, b, c, d: coefficients of the polynomial in ``cos(x)``.
        k: strength of the repulsive ``1/sin^2`` term.
        v_0: constant energy offset.

    Returns:
        Tensor of potential values (shaped like ``x`` for scalar parameters).
    """
    cos_x = torch.cos(x)
    wall = k / (torch.sin(x) ** 2)
    polynomial = a * cos_x**4 + b * cos_x**3 + c * cos_x**2 + d * cos_x
    return polynomial + wall + v_0
def dx_restricted_quartic_angle_numpy(x, a, b, c, d, k):
    """First derivative dV/dtheta of the restricted quartic angle potential.

    NumPy implementation (works on scalars or arrays). The constant offset
    ``v_0`` drops out of the derivative and is therefore not a parameter.

    Args:
        x: angle(s) in radians.
        a, b, c, d: coefficients of the polynomial in ``cos(x)``.
        k: strength of the repulsive ``1/sin^2`` term.

    Returns:
        dV/dtheta evaluated at ``x``.
    """
    cos_t = np.cos(x)
    sin_t = np.sin(x)
    # d/dtheta of a*cos^4 + b*cos^3 + c*cos^2 + d*cos
    poly_slope = (
        -4 * a * sin_t * cos_t**3
        - 3 * b * sin_t * cos_t**2
        - 2 * c * sin_t * cos_t
        - d * sin_t
    )
    # d/dtheta of k / sin^2 = -2k cos / sin^3
    wall_slope = -k * (2 * cos_t) / (sin_t**3)
    return poly_slope + wall_slope
def ddx2_restricted_quartic_angle_numpy(x, a, b, c, d, k):
    """Second derivative d^2V/dtheta^2 of the restricted quartic angle potential.

    NumPy implementation (works on scalars or arrays). Its sign tells a
    stationary point's type: positive curvature means a minimum and negative
    curvature means a maximum.

    Args:
        x: angle(s) in radians.
        a, b, c, d: coefficients of the polynomial in ``cos(x)``.
        k: strength of the repulsive ``1/sin^2`` term.

    Returns:
        d^2V/dtheta^2 evaluated at ``x``.
    """
    cos_t = np.cos(x)
    sin_t = np.sin(x)
    poly_curvature = (
        -4 * a * (cos_t**4 - 3 * sin_t**2 * cos_t**2)
        - 3 * b * (cos_t**3 - 2 * sin_t**2 * cos_t)
        - 2 * c * (cos_t**2 - sin_t**2)
        - d * cos_t
    )
    # d^2/dtheta^2 of k / sin^2 = 2k (1/sin^2 + 3 cos^2 / sin^4)
    wall_curvature = 2 * k * (1 / sin_t**2 + 3 * cos_t**2 / sin_t**4)
    return poly_curvature + wall_curvature
def find_extrema_in_range(
    a, b, c, d, k, x_range: Tuple[float, float]
) -> Tuple[List[float], List[float]]:
    """Locate stationary points of the restricted quartic angle potential.

    Twenty evenly spaced starting guesses are placed in ``x_range``, clipped
    to ``[0.01, 3.13]`` to stay clear of the ``1/sin^2`` singularities.
    ``fsolve`` is run on the first derivative from each guess. A root is
    accepted when its squared residual is below 1e-6 and it lies inside the
    *unclipped* ``x_range``. Each accepted root is classified by the sign of
    the second derivative; roots with zero or undefined curvature are
    dropped. Roots closer than 1e-4 to one already recorded are skipped.

    Args:
        a, b, c, d, k: potential parameters (the offset ``v_0`` does not
            affect extrema).
        x_range: ``(x_min, x_max)`` search interval.

    Returns:
        ``(minima, maxima)``: lists of angle values, sorted ascending.
    """
    minima, maxima = [], []
    lo = max(x_range[0], 0.01)
    hi = min(x_range[1], 3.13)
    if hi <= lo:
        return minima, maxima

    for guess in np.linspace(lo, hi, 20):
        try:
            solution, info, _ier, _msg = fsolve(
                dx_restricted_quartic_angle_numpy,
                guess,
                args=(a, b, c, d, k),
                full_output=True,
            )
            candidate = solution[0]
            converged = info["fvec"][0] ** 2 < 1e-6
            if not (converged and x_range[0] <= candidate <= x_range[1]):
                continue

            curvature = ddx2_restricted_quartic_angle_numpy(candidate, a, b, c, d, k)
            if curvature > 0:
                bucket = minima
            elif curvature < 0:
                bucket = maxima
            else:
                # Zero (or NaN) curvature: cannot classify the point.
                continue

            if not any(abs(candidate - seen) < 1e-4 for seen in bucket):
                bucket.append(candidate)
        except RuntimeError:
            print(f"Error finding extrema at x0={guess:.6f}")
            continue

    return sorted(minima), sorted(maxima)
def find_minmax(
    params, left_region: Tuple[float, float], right_region: Tuple[float, float]
) -> Tuple[List[float], List[float], List[float], List[float]]:
    """Search both tail regions of a fitted potential for stationary points.

    Args:
        params: the six fitted parameters ``(a, b, c, d, k, v_0)``. ``v_0``
            is ignored because a constant offset does not move extrema.
        left_region: ``(x_min, x_max)`` interval of the left tail.
        right_region: ``(x_min, x_max)`` interval of the right tail.

    Returns:
        ``(left_minima, left_maxima, right_minima, right_maxima)``.
    """
    a, b, c, d, k, _v_0 = params
    shape_params = (a, b, c, d, k)
    left_extrema = find_extrema_in_range(*shape_params, left_region)
    right_extrema = find_extrema_in_range(*shape_params, right_region)
    return (*left_extrema, *right_extrema)
def _nan_restricted_bending_stat() -> Dict:
    """Return a restricted-bending parameter dict whose values are all NaN."""
    return {
        name: torch.tensor(float("nan"))
        for name in ("a", "b", "c", "d", "k", "v_0")
    }


def fit_rb_from_potential_estimates(
    bin_centers_nz: torch.Tensor, dG_nz: torch.Tensor, **kwargs
) -> Dict:
    r"""Fit the restricted quartic angle potential to a potential estimate.

    Points are used in the fit only if their magnitude exceeds ``1e-4``
    times the absolute trapezoidal integral of ``dG_nz``. All six parameters
    are fitted first. :func:`find_minmax` then searches the tail regions
    outside the retained data range, ``(0, first retained x)`` and
    ``(last retained x, pi)``. If any minimum or maximum is found there, the
    potential is refit with ``a = b = 0``. That collapses it to a quadratic
    in ``cos(theta)`` plus the repulsive ``k / sin^2`` term.

    Args:
        bin_centers_nz: bin-center angles (radians).
        dG_nz: potential estimates evaluated at ``bin_centers_nz``.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        dict with keys ``"a", "b", "c", "d", "k", "v_0"``. Every value is a
        NaN tensor in two cases: when no point passes the magnitude filter,
        or when fitting raises.
    """
    # Work on detached CPU copies: scipy converts the targets to NumPy, which
    # fails for tensors on an accelerator or tensors that require grad.
    x_all = bin_centers_nz.detach().cpu()
    y_all = dG_nz.detach().cpu()

    integral = abs(float(trapezoid(y_all.numpy(), x_all.numpy())))
    # Keep only points that are non-negligible relative to the integral.
    mask = torch.abs(y_all) > 1e-4 * integral
    nonzero_indices = torch.where(mask)[0]
    if nonzero_indices.numel() == 0:
        # Nothing to fit (e.g. an all-zero estimate). Report failure the same
        # way as a failed fit instead of raising IndexError.
        print(
            "Failed to fit potential estimate for RestrictedQuartic angle: "
            "no points above the magnitude threshold"
        )
        return _nan_restricted_bending_stat()

    first_nonzero_x = x_all[nonzero_indices[0]].item()
    last_nonzero_x = x_all[nonzero_indices[-1]].item()
    # Tail regions lie outside the retained data range.
    left_tail = (0.0, first_nonzero_x)
    right_tail = (last_nonzero_x, np.pi)

    x_fit = x_all[mask]
    y_fit = y_all[mask]
    # v_0 is an energy offset, so seed it with the lowest energy value
    # (not the index of that value).
    v_0_guess = float(y_fit.min())

    try:
        popt, _ = curve_fit(
            restricted_quartic_angle,
            x_fit,
            y_fit,
            p0=[1, 0, 0, 0, 1e-3, v_0_guess],
            bounds=(
                (0, -np.inf, -np.inf, -np.inf, 1e-5, -np.inf),
                (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf),
            ),
            maxfev=5000,
        )
        left_minima, left_maxima, right_minima, right_maxima = find_minmax(
            popt, left_tail, right_tail
        )
        tail_extrema = {
            "left minima": left_minima,
            "left maxima": left_maxima,
            "right minima": right_minima,
            "right maxima": right_maxima,
        }
        extrema_info = [
            f"{label}: {values}" for label, values in tail_extrema.items() if values
        ]

        if extrema_info:
            print(
                f"Extrema found in tail regions ({', '.join(extrema_info)}). "
                "Refitting with a=0, b=0"
            )

            # Refit with a=0 and b=0 (quadratic in cos(theta) + repulsive term).
            def restricted_quartic_angle_constrained(x, c, d, k, v_0):
                return restricted_quartic_angle(x, 0, 0, c, d, k, v_0)

            popt_constrained, _ = curve_fit(
                restricted_quartic_angle_constrained,
                x_fit,
                y_fit,
                p0=[0, 0, 1e-3, v_0_guess],
                bounds=(
                    (-np.inf, -np.inf, 1e-5, -np.inf),
                    (np.inf, np.inf, np.inf, np.inf),
                ),
                maxfev=5000,
            )
            stat = {
                "a": torch.tensor(0.0),
                "b": torch.tensor(0.0),
                "c": popt_constrained[0],
                "d": popt_constrained[1],
                "k": popt_constrained[2],
                "v_0": popt_constrained[3],
            }
        else:
            stat = dict(zip(("a", "b", "c", "d", "k", "v_0"), popt))
    except Exception as e:
        # Best effort: any fitting failure yields an all-NaN parameter set.
        print(f"Failed to fit potential estimate for RestrictedQuartic angle: {e}")
        stat = _nan_restricted_bending_stat()
    return stat