Source code for mqed.plotting.plot_sqrt_msd

from __future__ import annotations
from pathlib import Path
import numpy as np
import h5py
import matplotlib.pyplot as plt
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
from loguru import logger

from mqed.utils.file_utils import _resolve_input_path

from mqed.utils.logging_utils import setup_loggers_hydra_aware
from mqed.utils.hydra_local import prepare_hydra_config_path


def _to_plot_time(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
    """Convert a time axis given in picoseconds to the configured plot unit.

    The unit is read from ``cfg_ps.time_unit`` (case-insensitive); when the
    attribute is missing, ``"ps"`` is assumed.

    Args:
        t_ps: Time values in picoseconds.
        cfg_ps: Plot-settings object; only its ``time_unit`` attribute is read.

    Returns:
        ``t_ps`` itself (no copy) for ``"ps"``; otherwise ``t_ps`` multiplied
        by ``1e3`` (``"fs"``) or ``1e-12`` (``"s"``).

    Raises:
        ValueError: If the unit is not one of ``"fs"``, ``"ps"`` or ``"s"``.
    """
    unit = str(getattr(cfg_ps, "time_unit", "ps")).lower()
    if unit == "ps":
        # Nothing to convert: hand the caller's array straight back.
        return t_ps

    # Multiplicative factors that take picoseconds to the target unit.
    factor_from_ps = {"fs": 1.0e3, "s": 1.0e-12}
    factor = factor_from_ps.get(unit)
    if factor is None:
        raise ValueError(f"Unsupported plot_settings.time_unit='{unit}'. Use 'fs', 'ps', or 's'.")
    return t_ps * factor


def _msd_analytical_local(t_fs: np.ndarray, a: float, hbar_eV_fs: float, j_0_eV: float, sigma_j_eV: float) -> np.ndarray:
    """Evaluate ``2 * a**2 * (J_0**2 + sigma_J**2) * t**2 / hbar**2``.

    Args:
        t_fs: Time values (the name indicates femtoseconds).
        a: Length scale multiplying the result as ``a**2``.
        hbar_eV_fs: Value used for hbar in the denominator.
        j_0_eV: Coupling value ``J_0``.
        sigma_j_eV: Coupling spread ``sigma_J``; enters in quadrature with
            ``J_0``.

    Returns:
        Array with the same shape as ``t_fs`` holding the expression above.
    """
    coupling_sq = j_0_eV ** 2 + sigma_j_eV ** 2
    numerator = 2.0 * a ** 2 * coupling_sq * t_fs ** 2
    return numerator / hbar_eV_fs ** 2


def _x_square_analytical_gaussian(
    t_fs: np.ndarray,
    a: float,
    hbar_eV_fs: float,
    j_0_eV: float,
    sigma_j_eV: float,
    k_parallel: float,
    sigma_sites: float,
) -> np.ndarray:
    """Evaluate the closed-form second moment used for the Gaussian model.

    The returned quantity is::

        a**2 * ( sigma_sites**2 / 2
                 + ( 4 J_0**2 sin(k_parallel * a)**2 / hbar**2
                     + 2 sigma_J**2 / hbar**2 ) * t**2 )

    Args:
        t_fs: Time values (the name indicates femtoseconds).
        a: Length scale; also multiplies ``k_parallel`` inside the sine.
        hbar_eV_fs: Value used for hbar.
        j_0_eV: Coupling value ``J_0``.
        sigma_j_eV: Coupling spread ``sigma_J``.
        k_parallel: Wave number entering through ``sin(k_parallel * a)``.
        sigma_sites: Width parameter giving the time-independent offset.

    Returns:
        Array with the same shape as ``t_fs``.
    """
    # Time-independent offset.
    static_part = sigma_sites ** 2 / 2.0
    # Two coefficients of t**2: one depends on k_parallel, one on sigma_J.
    k_dependent_rate = (4.0 * j_0_eV ** 2 / hbar_eV_fs ** 2) * np.sin(k_parallel * a) ** 2
    spread_rate = 2.0 * sigma_j_eV ** 2 / hbar_eV_fs ** 2
    growth = (k_dependent_rate + spread_rate) * t_fs ** 2
    return a ** 2 * (static_part + growth)


def _position_analytical_gaussian(t_fs: np.ndarray, a: float, hbar_eV_fs: float, j_0_eV: float, k_parallel: float) -> np.ndarray:
    """Evaluate the closed-form first moment used for the Gaussian model.

    The returned quantity is linear in time::

        a * (-2 J_0 / hbar) * sin(k_parallel * a) * t

    Args:
        t_fs: Time values (the name indicates femtoseconds).
        a: Length scale; also multiplies ``k_parallel`` inside the sine.
        hbar_eV_fs: Value used for hbar.
        j_0_eV: Coupling value ``J_0``.
        k_parallel: Wave number entering through ``sin(k_parallel * a)``.

    Returns:
        Array with the same shape as ``t_fs``.
    """
    rate = (-2.0 * j_0_eV / hbar_eV_fs) * np.sin(k_parallel * a)
    slope = a * rate
    return slope * t_fs


def _nn_rmsd_analytical(t_fs: np.ndarray, model: str, params: dict) -> np.ndarray:
    """Return the analytical curve for ``model`` evaluated on ``t_fs``.

    Args:
        t_fs: Time values passed through to the closed-form helpers.
        model: ``"local_excitation"`` or ``"gaussian_wave"``.
        params: Parameter mapping. ``"J_0_eV"`` and ``"sigma_J_eV"`` are
            always required; ``"a"`` defaults to ``1.0`` and ``"hbar_eV_fs"``
            to ``0.6582119514``. ``"gaussian_wave"`` additionally requires
            ``"k_parallel"`` and ``"sigma_sites"``.

    Returns:
        For ``"local_excitation"``: the square root of
        ``_msd_analytical_local``. For ``"gaussian_wave"``: the square root
        of ``x2 - x**2`` built from ``_x_square_analytical_gaussian`` and
        ``_position_analytical_gaussian``. Negative radicands are clipped to
        zero before the square root in both cases.

    Raises:
        KeyError: If a required entry is missing from ``params``.
        ValueError: If ``model`` is not one of the two supported names.
    """
    # The shared parameters are read before the model name is inspected, so a
    # missing coupling key raises KeyError even when the model is unknown.
    lattice_a = float(params.get("a", 1.0))
    hbar = float(params.get("hbar_eV_fs", 0.6582119514))
    coupling = float(params["J_0_eV"])
    coupling_spread = float(params["sigma_J_eV"])

    if model not in ("local_excitation", "gaussian_wave"):
        raise ValueError(f"Unsupported analytical model '{model}'. Use 'local_excitation' or 'gaussian_wave'.")

    if model == "local_excitation":
        radicand = _msd_analytical_local(t_fs, lattice_a, hbar, coupling, coupling_spread)
    else:
        wave_number = float(params["k_parallel"])
        width = float(params["sigma_sites"])
        # Second moment minus the squared first moment.
        second_moment = _x_square_analytical_gaussian(
            t_fs, lattice_a, hbar, coupling, coupling_spread, wave_number, width
        )
        first_moment = _position_analytical_gaussian(t_fs, lattice_a, hbar, coupling, wave_number)
        radicand = second_moment - first_moment ** 2

    # Clip tiny negative values so the square root stays real.
    return np.sqrt(np.maximum(0.0, radicand))

# Layouts accepted inside the "expectations" group, tried in this order:
# (second-moment key, first-moment key, value stored in meta["source"],
#  dataset names quoted in the "invalid dataset" error message).
_EXPECTATION_LAYOUTS = (
    ("X_shift2", "X_shift", "expectations/X_shift,X_shift2", "X_shift or X_shift2"),
    ("x2_mean", "position_mean", "expectations/x2_mean,position_mean", "x2_mean or position_mean"),
    ("msd_mean", "position_mean", "expectations/msd_mean", "msd_mean or position_mean"),
)


def _read_flat_dataset(dataset) -> np.ndarray:
    """Read a whole HDF5 dataset and flatten it to one dimension."""
    return np.asarray(dataset[...]).ravel()


def _dx_from_expectations(ex_group, h5_path: Path) -> tuple[np.ndarray, str]:
    """Derive Δx as ``sqrt(max(0, second - first**2))`` from the expectations group.

    Returns the Δx array and the label to store in ``meta["source"]``.
    Raises ``ValueError`` when no supported pair of entries is present or
    when a matching entry is not a dataset.
    """
    if isinstance(ex_group, h5py.Group):
        for second_key, first_key, source, names in _EXPECTATION_LAYOUTS:
            if second_key not in ex_group or first_key not in ex_group:
                continue
            second_ds = ex_group.get(second_key)
            first_ds = ex_group.get(first_key)
            # Both members must be real datasets (not sub-groups); this is
            # checked for every layout, including msd_mean/position_mean.
            if not isinstance(second_ds, h5py.Dataset) or not isinstance(first_ds, h5py.Dataset):
                raise ValueError(f"{h5_path} has invalid expectations/{names} dataset.")
            second = _read_flat_dataset(second_ds)
            first = _read_flat_dataset(first_ds)
            return np.sqrt(np.maximum(0.0, second - first ** 2)), source
    raise ValueError(
        f"{h5_path} does not contain 'dx_mean_nm'/'dx_nm' "
        "or expectations/['X_shift','X_shift2'], ['x2_mean','position_mean'] "
        "or ['msd_mean','position_mean']."
    )


def _load_dx_and_time(h5_path: Path) -> tuple[np.ndarray, np.ndarray, dict]:
    """Load the time axis and the Δx curve from an HDF5 result file.

    Sources for Δx, in order of preference:

    1. dataset ``dx_mean_nm``
    2. dataset ``dx_nm``
    3. group ``expectations`` containing one of the pairs
       ``X_shift``/``X_shift2``, ``x2_mean``/``position_mean`` or
       ``msd_mean``/``position_mean``; Δx is then
       ``sqrt(max(0, second - first**2))``.

    Args:
        h5_path: Path of the HDF5 file to read.

    Returns:
        ``(t_ps, dx, meta)`` where ``t_ps`` and ``dx`` are flattened 1-D
        arrays of identical shape, and ``meta`` holds ``"source"`` (which of
        the layouts above was used) plus any of the optional file attributes
        that are present.

    Raises:
        ValueError: If ``t_ps`` is missing, no supported Δx source is found,
            a matching expectations entry is not a dataset, or the shapes of
            Δx and ``t_ps`` differ.
    """
    logger.info(f"Loading Δx data from {h5_path}")
    meta: dict = {}
    with h5py.File(str(h5_path), "r") as f:
        t_ps_ds = f.get("t_ps")
        if not isinstance(t_ps_ds, h5py.Dataset):
            raise ValueError(f"{h5_path} has no 't_ps' dataset.")
        t_ps = _read_flat_dataset(t_ps_ds)

        # Prefer a directly stored Δx; fall back to the expectations group.
        for name in ("dx_mean_nm", "dx_nm"):
            node = f.get(name)
            if isinstance(node, h5py.Dataset):
                dx = _read_flat_dataset(node)
                meta["source"] = name
                break
        else:
            dx, meta["source"] = _dx_from_expectations(f.get("expectations"), h5_path)

        # Optional file attributes are copied verbatim when present.
        for key in (
            "method",
            "n_realizations",
            "sigma_phi_deg",
            "seed_base",
            "state_format",
            "J_0_eV",
            "sigma_J_eV",
            "k_parallel",
            "sigma_sites",
            "eps_0_eV",
        ):
            if key in f.attrs:
                meta[key] = f.attrs[key]

    if dx.shape != t_ps.shape:
        raise ValueError(f"{h5_path} Δx shape {dx.shape} and t_ps shape {t_ps.shape} mismatch.")
    return t_ps, dx, meta
def _select_x(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
    """Return a boolean mask selecting the samples of ``t_ps`` to plot.

    The first setting that is present and truthy wins:

    1. ``x_index_range`` -- inclusive ``[first, last]`` sample indices. Both
       ends are clamped to ``[0, len(t_ps)]``; negative values are NOT
       interpreted as "from the end", so a negative ``last`` selects nothing.
    2. ``x_range`` -- inclusive ``[min, max]`` in the plot time unit (as
       produced by ``_to_plot_time``).
    3. ``x_range_ps`` -- inclusive ``[min, max]`` in picoseconds.
    4. ``x_range_fs`` -- inclusive ``[min, max]`` in femtoseconds.

    Args:
        t_ps: Time values in picoseconds.
        cfg_ps: Plot-settings object carrying the optional attributes above.

    Returns:
        Boolean array shaped like ``t_ps``; all ``True`` when no range is
        configured.
    """
    if hasattr(cfg_ps, "x_index_range") and cfg_ps.x_index_range:
        first, last = int(cfg_ps.x_index_range[0]), int(cfg_ps.x_index_range[1])
        n_samples = len(t_ps)
        # Clamp both ends: without the lower clamp on the stop index a
        # negative `last` would wrap around via Python's negative slicing.
        start = min(n_samples, max(0, first))
        stop = min(n_samples, max(0, last + 1))
        mask = np.zeros_like(t_ps, dtype=bool)
        mask[start:stop] = True
        return mask

    if hasattr(cfg_ps, "x_range") and cfg_ps.x_range:
        t_plot = _to_plot_time(t_ps, cfg_ps)
        lower, upper = float(cfg_ps.x_range[0]), float(cfg_ps.x_range[1])
        return (t_plot >= lower) & (t_plot <= upper)

    if hasattr(cfg_ps, "x_range_ps") and cfg_ps.x_range_ps:
        lower, upper = float(cfg_ps.x_range_ps[0]), float(cfg_ps.x_range_ps[1])
        return (t_ps >= lower) & (t_ps <= upper)

    if hasattr(cfg_ps, "x_range_fs") and cfg_ps.x_range_fs:
        t_fs = t_ps * 1.0e3
        lower, upper = float(cfg_ps.x_range_fs[0]), float(cfg_ps.x_range_fs[1])
        return (t_fs >= lower) & (t_fs <= upper)

    return np.ones_like(t_ps, dtype=bool)
HYDRA_CONFIG_PATH: str = prepare_hydra_config_path("plots", __file__)


def _curve_style(curve, ps, default_label):
    """Return ``(fmt, kwargs)`` for ``ax.plot`` from a curve entry.

    Per-curve settings win; ``lw`` falls back to ``ps.lw`` and then ``1.5``,
    ``label`` falls back to ``default_label``.
    """
    fmt = getattr(curve, "style", "-")
    kwargs = {
        "lw": getattr(curve, "lw", ps.get("lw", 1.5)),
        "label": getattr(curve, "label", default_label),
        "color": getattr(curve, "color", None),
    }
    return fmt, kwargs


@hydra.main(config_path=HYDRA_CONFIG_PATH, config_name="sqrt_msd", version_base=None)
def main(cfg: DictConfig) -> None:
    """Plot the Δx curves listed in ``cfg.curves`` and optional analytical overlays.

    Steps:

    1. Load every entry of ``cfg.curves``, apply the x selection from
       ``cfg.plot_settings`` and draw it.
    2. Unless ``plot_settings.enable_analytical_curves`` is false, draw each
       entry of ``cfg.analytical_curves`` on the time grid of the numerical
       curve picked by ``from_curve_index`` (default ``0``). Missing model
       parameters are filled from that curve's file attributes.
    3. Apply labels, scales, limits, grid and legend, then save and/or show
       the figure according to ``plot_settings``.

    The figure is always closed, also when loading or plotting fails.

    Args:
        cfg: Hydra configuration (config name ``sqrt_msd``).

    Raises:
        ValueError: If an analytical curve is requested without any numerical
            curve, or its ``from_curve_index`` is out of range.
    """
    outdir = Path(HydraConfig.get().runtime.output_dir)
    setup_loggers_hydra_aware()

    ps = cfg.plot_settings
    figsize = (ps.figsize[0], ps.figsize[1]) if getattr(ps, "figsize", None) else (7, 5)
    fig, ax = plt.subplots(figsize=figsize)

    try:
        # --- numerical curves -------------------------------------------
        loaded_curves = []
        for curve in cfg.curves:
            path = _resolve_input_path(curve)
            t_ps, dx_nm, meta = _load_dx_and_time(path)
            sel = _select_x(t_ps, ps)
            x = _to_plot_time(t_ps[sel], ps) * getattr(ps, "x_scale_factor", 1.0)
            y = dx_nm[sel]

            fmt, style_kwargs = _curve_style(curve, ps, path.stem)
            ax.plot(x, y, fmt, **style_kwargs)
            label = style_kwargs["label"]

            # Kept so analytical overlays can reuse this curve's time grid.
            loaded_curves.append({"t_ps": t_ps, "sel": sel, "meta": meta})
            logger.info(f"Plotted {label} from {path.name} (source={meta.get('source','?')})")

        # --- analytical overlays ----------------------------------------
        if bool(getattr(ps, "enable_analytical_curves", True)):
            # `or []` tolerates an explicit `analytical_curves: null`.
            analytical_curves = getattr(cfg, "analytical_curves", None) or []
            for curve in analytical_curves:
                if not loaded_curves:
                    raise ValueError("analytical_curves requires at least one numerical curve to define time grid.")
                from_curve_index = int(getattr(curve, "from_curve_index", 0))
                if from_curve_index < 0 or from_curve_index >= len(loaded_curves):
                    raise ValueError(
                        f"analytical from_curve_index={from_curve_index} out of range [0, {len(loaded_curves)-1}]."
                    )
                ref = loaded_curves[from_curve_index]
                t_ps_ref = np.asarray(ref["t_ps"])
                sel = np.asarray(ref["sel"], dtype=bool)
                meta = dict(ref["meta"])

                # Explicit params win; file attributes only fill the gaps.
                params = dict(getattr(curve, "params", {}) or {})
                for key in ("J_0_eV", "sigma_J_eV", "k_parallel", "sigma_sites"):
                    if key not in params and key in meta:
                        params[key] = meta[key]

                model = str(getattr(curve, "model", "gaussian_wave"))
                # The analytical helpers take femtoseconds; the file stores ps.
                t_fs = t_ps_ref[sel] * 1.0e3
                y = _nn_rmsd_analytical(t_fs, model, params)
                x = _to_plot_time(t_ps_ref[sel], ps) * getattr(ps, "x_scale_factor", 1.0)

                fmt, style_kwargs = _curve_style(curve, ps, f"Analytical RMSD ({model})")
                ax.plot(x, y, fmt, **style_kwargs)
                label = style_kwargs["label"]
                logger.info(
                    f"Plotted analytical curve: {label} (model={model}, from_curve_index={from_curve_index})"
                )

        # --- axes decoration --------------------------------------------
        ax.set_xlabel(ps.xlabel)
        ax.set_ylabel(ps.ylabel)
        if getattr(ps, "title", None):
            ax.set_title(ps.title)

        if getattr(ps, "xscale", None):
            ax.set_xscale(ps.xscale)
        if getattr(ps, "yscale", None):
            ax.set_yscale(ps.yscale)

        if getattr(ps, "xlim", None):
            ax.set_xlim(ps.xlim[0], ps.xlim[1])
        if getattr(ps, "ylim", None):
            ax.set_ylim(ps.ylim[0], ps.ylim[1])

        if getattr(ps, "grid", True):
            ax.grid(True, which="both", ls="--", alpha=0.5)
        if getattr(ps, "legend", True):
            ax.legend()
        if getattr(ps, "tight_layout", True):
            plt.tight_layout()

        # --- output -----------------------------------------------------
        if getattr(ps, "save_plot", True):
            name = getattr(ps, "filename", "sqrt_msd.png")
            figpath = outdir / name
            fig.savefig(figpath, dpi=getattr(ps, "dpi", 300), bbox_inches="tight")
            logger.success(f"Saved plot → {figpath}")
        if getattr(ps, "show", False):
            plt.show()
    finally:
        # Release the figure even if loading or plotting raised.
        plt.close(fig)


if __name__ == "__main__":
    main()