from __future__ import annotations
from pathlib import Path
import numpy as np
import h5py
import matplotlib.pyplot as plt
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
from loguru import logger
from mqed.utils.file_utils import _resolve_input_path
from mqed.utils.logging_utils import setup_loggers_hydra_aware
from mqed.utils.hydra_local import prepare_hydra_config_path
def _to_plot_time(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
    """Convert a time axis given in picoseconds to the configured plot unit.

    The unit is read from ``cfg_ps.time_unit`` (case-insensitive, default
    ``"ps"``).  For ``"ps"`` the input array is returned unchanged (same
    object); for ``"fs"`` and ``"s"`` a rescaled copy is returned.

    Raises:
        ValueError: if the unit is not one of ``"fs"``, ``"ps"`` or ``"s"``.
    """
    unit = str(getattr(cfg_ps, "time_unit", "ps")).lower()
    if unit == "ps":
        # No conversion needed; hand back the caller's array as-is.
        return t_ps

    # Multiplicative factors that take picoseconds to the target unit.
    ps_to_unit = {"fs": 1.0e3, "s": 1.0e-12}
    factor = ps_to_unit.get(unit)
    if factor is None:
        raise ValueError(f"Unsupported plot_settings.time_unit='{unit}'. Use 'fs', 'ps', or 's'.")
    return t_ps * factor
def _msd_analytical_local(
    t_fs: np.ndarray,
    a: float,
    hbar_eV_fs: float,
    j_0_eV: float,
    sigma_j_eV: float,
) -> np.ndarray:
    """Evaluate ``2 a^2 (J_0^2 + sigma_J^2) t^2 / hbar^2`` element-wise over *t_fs*.

    Used as the closed-form mean-squared displacement for the
    ``local_excitation`` analytical model.
    """
    coupling_sq = j_0_eV ** 2 + sigma_j_eV ** 2
    prefactor = 2.0 * a ** 2 * coupling_sq
    return prefactor * t_fs ** 2 / (hbar_eV_fs ** 2)
def _x_square_analytical_gaussian(
    t_fs: np.ndarray,
    a: float,
    hbar_eV_fs: float,
    j_0_eV: float,
    sigma_j_eV: float,
    k_parallel: float,
    sigma_sites: float,
) -> np.ndarray:
    """Evaluate the closed-form second moment used by the ``gaussian_wave`` model.

    Computes, element-wise over *t_fs*::

        a^2 * ( sigma_sites^2 / 2
                + [ (4 J_0^2 / hbar^2) sin^2(k_parallel * a)
                    + 2 sigma_J^2 / hbar^2 ] * t^2 )
    """
    # Time-independent contribution.
    width_term = sigma_sites ** 2 / 2.0
    # Coefficients of the t^2 growth: one from J_0, one from sigma_J.
    hopping_coeff = (4.0 * j_0_eV ** 2 / hbar_eV_fs ** 2) * np.sin(k_parallel * a) ** 2
    fluctuation_coeff = 2.0 * sigma_j_eV ** 2 / hbar_eV_fs ** 2
    quadratic_coeff = hopping_coeff + fluctuation_coeff
    return a ** 2 * (width_term + quadratic_coeff * t_fs ** 2)
def _position_analytical_gaussian(
    t_fs: np.ndarray,
    a: float,
    hbar_eV_fs: float,
    j_0_eV: float,
    k_parallel: float,
) -> np.ndarray:
    """Evaluate the closed-form mean position used by the ``gaussian_wave`` model.

    Computes ``a * (-2 J_0 / hbar) * sin(k_parallel * a) * t`` element-wise
    over *t_fs*, i.e. a straight line through the origin.
    """
    rate = (-2.0 * j_0_eV / hbar_eV_fs) * np.sin(k_parallel * a)
    slope = a * rate
    return slope * t_fs
def _nn_rmsd_analytical(t_fs: np.ndarray, model: str, params: dict) -> np.ndarray:
    """Return the analytical spread curve for *model*, evaluated on *t_fs*.

    Args:
        t_fs: time grid passed through to the closed-form helpers.
        model: ``"local_excitation"`` or ``"gaussian_wave"``.
        params: must contain ``J_0_eV`` and ``sigma_J_eV``; ``a`` defaults to
            1.0 and ``hbar_eV_fs`` to 0.6582119514.  ``gaussian_wave``
            additionally requires ``k_parallel`` and ``sigma_sites``.

    Returns:
        ``sqrt(max(0, s))`` where ``s`` is the mean-squared displacement
        (``local_excitation``) or the second moment minus the squared mean
        position (``gaussian_wave``).

    Raises:
        KeyError: if a required entry is missing from *params*.
        ValueError: if *model* is not one of the two supported names.
    """
    lattice_a = float(params.get("a", 1.0))
    hbar = float(params.get("hbar_eV_fs", 0.6582119514))
    j_0 = float(params["J_0_eV"])
    sigma_j = float(params["sigma_J_eV"])

    if model not in ("local_excitation", "gaussian_wave"):
        raise ValueError(f"Unsupported analytical model '{model}'. Use 'local_excitation' or 'gaussian_wave'.")

    if model == "local_excitation":
        spread_sq = _msd_analytical_local(t_fs, lattice_a, hbar, j_0, sigma_j)
    else:
        k_par = float(params["k_parallel"])
        width_sites = float(params["sigma_sites"])
        # For the wave packet the reported spread is the standard deviation:
        # second moment minus the squared mean position.
        second_moment = _x_square_analytical_gaussian(
            t_fs, lattice_a, hbar, j_0, sigma_j, k_par, width_sites
        )
        mean_position = _position_analytical_gaussian(t_fs, lattice_a, hbar, j_0, k_par)
        spread_sq = second_moment - mean_position ** 2

    # Clamp tiny negative values from round-off before taking the root.
    return np.sqrt(np.maximum(0.0, spread_sq))
[docs]
def _read_expectation_pair(group, key_a: str, key_b: str, h5_path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Read ``group[key_a]`` and ``group[key_b]`` as flattened arrays.

    Raises:
        ValueError: if either entry is not an ``h5py.Dataset``.
    """
    ds_a = group.get(key_a)
    ds_b = group.get(key_b)
    if not isinstance(ds_a, h5py.Dataset) or not isinstance(ds_b, h5py.Dataset):
        raise ValueError(f"{h5_path} has invalid expectations/{key_a} or {key_b} dataset.")
    return np.asarray(ds_a[...]).ravel(), np.asarray(ds_b[...]).ravel()


def _load_dx_and_time(h5_path: Path) -> tuple[np.ndarray, np.ndarray, dict]:
    """Load the time axis and the displacement spread Δx from an HDF5 file.

    The time axis is read from the required dataset ``t_ps``.  Δx is taken
    from the first layout found, in this order:

    1. dataset ``dx_mean_nm``
    2. dataset ``dx_nm``
    3. group ``expectations`` with ``X_shift`` and ``X_shift2``
    4. group ``expectations`` with ``x2_mean`` and ``position_mean``
    5. group ``expectations`` with ``msd_mean`` and ``position_mean``

    For layouts 3-5, Δx is ``sqrt(max(0, second - first**2))`` where
    ``second`` is ``X_shift2`` / ``x2_mean`` / ``msd_mean`` and ``first`` is
    ``X_shift`` / ``position_mean``.

    Args:
        h5_path: path of the HDF5 file to open read-only.

    Returns:
        ``(t_ps, dx, meta)``: two flattened 1-D arrays of equal shape, and a
        dict holding ``"source"`` (which layout was used) plus any of the
        optional file attributes (``method``, ``n_realizations``,
        ``J_0_eV``, ...) that are present on the file.

    Raises:
        ValueError: if ``t_ps`` is missing, no supported Δx layout is found,
            a matched expectations entry is not a dataset, or the shapes of
            Δx and ``t_ps`` differ.
    """
    logger.info(f"Loading Δx data from {h5_path}")
    meta = {}
    with h5py.File(str(h5_path), "r") as f:
        # Time axis is mandatory.
        t_ps_ds = f.get("t_ps")
        if not isinstance(t_ps_ds, h5py.Dataset):
            raise ValueError(f"{h5_path} has no 't_ps' dataset.")
        t_ps = np.asarray(t_ps_ds[...]).ravel()

        dx_mean_ds = f.get("dx_mean_nm")
        dx_ds = f.get("dx_nm")
        ex_group = f.get("expectations")
        has_expectations = isinstance(ex_group, h5py.Group)

        if isinstance(dx_mean_ds, h5py.Dataset):
            dx = np.asarray(dx_mean_ds[...]).ravel()
            meta["source"] = "dx_mean_nm"
        elif isinstance(dx_ds, h5py.Dataset):
            dx = np.asarray(dx_ds[...]).ravel()
            meta["source"] = "dx_nm"
        else:
            # Fall back to reconstructing Δx from first/second moments.
            if has_expectations and "X_shift" in ex_group and "X_shift2" in ex_group:
                first, second = _read_expectation_pair(ex_group, "X_shift", "X_shift2", h5_path)
                meta["source"] = "expectations/X_shift,X_shift2"
            elif has_expectations and "x2_mean" in ex_group and "position_mean" in ex_group:
                second, first = _read_expectation_pair(ex_group, "x2_mean", "position_mean", h5_path)
                meta["source"] = "expectations/x2_mean,position_mean"
            elif has_expectations and "msd_mean" in ex_group and "position_mean" in ex_group:
                # Both entries are validated here, matching the x2_mean layout.
                second, first = _read_expectation_pair(ex_group, "msd_mean", "position_mean", h5_path)
                meta["source"] = "expectations/msd_mean"
            else:
                raise ValueError(
                    f"{h5_path} does not contain 'dx_mean_nm'/'dx_nm' "
                    "or an expectations group with ['X_shift','X_shift2'], "
                    "['x2_mean','position_mean'], or ['msd_mean','position_mean']."
                )
            # Clamp round-off negatives before the square root.
            dx = np.sqrt(np.maximum(0.0, second - first ** 2))

        # Optional file attributes are copied verbatim when present.
        for k in (
            "method",
            "n_realizations",
            "sigma_phi_deg",
            "seed_base",
            "state_format",
            "J_0_eV",
            "sigma_J_eV",
            "k_parallel",
            "sigma_sites",
            "eps_0_eV",
        ):
            if k in f.attrs:
                meta[k] = f.attrs[k]

    if dx.shape != t_ps.shape:
        raise ValueError(f"{h5_path} Δx shape {dx.shape} and t_ps shape {t_ps.shape} mismatch.")
    return t_ps, dx, meta
[docs]
def _select_x(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
    """Build a boolean mask over *t_ps* selecting the samples to plot.

    The first non-empty setting on *cfg_ps* wins, checked in this order:

    1. ``x_index_range`` - inclusive ``[i0, i1]`` sample indices
    2. ``x_range`` - inclusive window in the configured plot time unit
    3. ``x_range_ps`` - inclusive window in picoseconds
    4. ``x_range_fs`` - inclusive window in femtoseconds

    With none of them set, every sample is selected.
    """
    index_window = getattr(cfg_ps, "x_index_range", None)
    if index_window:
        first, last = int(index_window[0]), int(index_window[1])
        mask = np.zeros_like(t_ps, dtype=bool)
        start = max(0, first)
        stop = min(len(t_ps), last + 1)  # +1 makes the upper index inclusive
        mask[start:stop] = True
        return mask

    # Each entry pairs a config key with a lazy producer of the axis the
    # window bounds are expressed in.
    time_windows = (
        ("x_range", lambda: _to_plot_time(t_ps, cfg_ps)),
        ("x_range_ps", lambda: t_ps),
        ("x_range_fs", lambda: t_ps * 1.0e3),
    )
    for key, make_axis in time_windows:
        window = getattr(cfg_ps, key, None)
        if not window:
            continue
        axis = make_axis()
        lower, upper = float(window[0]), float(window[1])
        return (axis >= lower) & (axis <= upper)

    return np.ones_like(t_ps, dtype=bool)
# Config directory passed to ``hydra.main`` as ``config_path``; evaluated once at import time.
# NOTE(review): presumably resolves the "plots" config folder relative to this file —
# confirm against mqed.utils.hydra_local.prepare_hydra_config_path.
HYDRA_CONFIG_PATH: str = prepare_hydra_config_path("plots", __file__)
def _curve_plot_args(curve, ps, default_label: str) -> tuple[str, dict]:
    """Return the Matplotlib format string and line kwargs for one curve config entry.

    Per-curve ``style``, ``lw``, ``label`` and ``color`` take precedence;
    ``lw`` falls back to ``plot_settings.lw`` (then 1.5), ``label`` to
    *default_label*, ``style`` to ``"-"`` and ``color`` to ``None``.
    """
    fmt = getattr(curve, "style", "-")
    line_kwargs = {
        "lw": getattr(curve, "lw", ps.get("lw", 1.5)),
        "label": getattr(curve, "label", default_label),
        "color": getattr(curve, "color", None),
    }
    return fmt, line_kwargs


@hydra.main(config_path=HYDRA_CONFIG_PATH, config_name="sqrt_msd", version_base=None)
def main(cfg: DictConfig) -> None:
    """Plot Δx(t) curves loaded from HDF5 files, plus optional analytical overlays.

    Reads from *cfg*:

    * ``plot_settings`` - figure options (``figsize``, ``xlabel``, ``ylabel``,
      ``title``, scales, limits, ``grid``, ``legend``, ``tight_layout``,
      ``save_plot``, ``filename``, ``dpi``, ``show``), the time-axis options
      (``time_unit``, ``x_scale_factor`` and the x-range selectors) and
      ``enable_analytical_curves``.
    * ``curves`` - one entry per numerical curve; each is resolved to an
      input file and may carry ``style``, ``lw``, ``label`` and ``color``.
    * ``analytical_curves`` (optional) - entries with ``model``, ``params``
      and ``from_curve_index``; the referenced numerical curve supplies the
      time grid and default values for ``J_0_eV``, ``sigma_J_eV``,
      ``k_parallel`` and ``sigma_sites``.

    The figure is written into Hydra's run output directory when
    ``save_plot`` is true, shown when ``show`` is true, and always closed
    before returning - also when loading or plotting fails.

    Raises:
        ValueError: if analytical curves are requested without any numerical
            curve, or ``from_curve_index`` is out of range.
    """
    outdir = Path(HydraConfig.get().runtime.output_dir)
    setup_loggers_hydra_aware()
    ps = cfg.plot_settings

    figsize = (ps.figsize[0], ps.figsize[1]) if getattr(ps, "figsize", None) else (7, 5)
    fig, ax = plt.subplots(figsize=figsize)
    try:
        x_scale = getattr(ps, "x_scale_factor", 1.0)

        # --- numerical curves -------------------------------------------------
        loaded_curves = []
        for curve in cfg.curves:
            path = _resolve_input_path(curve)
            t_ps, dx_nm, meta = _load_dx_and_time(path)
            sel = _select_x(t_ps, ps)
            x = _to_plot_time(t_ps[sel], ps) * x_scale
            y = dx_nm[sel]

            fmt, line_kwargs = _curve_plot_args(curve, ps, path.stem)
            ax.plot(x, y, fmt, **line_kwargs)
            # Keep the unmasked grid, mask and metadata for analytical overlays.
            loaded_curves.append({"t_ps": t_ps, "sel": sel, "meta": meta})
            label = line_kwargs["label"]
            logger.info(f"Plotted {label} from {path.name} (source={meta.get('source','?')})")

        # --- analytical overlays ----------------------------------------------
        analytical_specs = []
        if bool(getattr(ps, "enable_analytical_curves", True)):
            # `or []` tolerates an explicit `analytical_curves: null` in the config.
            analytical_specs = list(getattr(cfg, "analytical_curves", None) or [])
        if analytical_specs and not loaded_curves:
            raise ValueError("analytical_curves requires at least one numerical curve to define time grid.")

        for curve in analytical_specs:
            from_curve_index = int(getattr(curve, "from_curve_index", 0))
            if from_curve_index < 0 or from_curve_index >= len(loaded_curves):
                raise ValueError(
                    f"analytical from_curve_index={from_curve_index} out of range [0, {len(loaded_curves)-1}]."
                )
            ref = loaded_curves[from_curve_index]
            t_ps_ref = np.asarray(ref["t_ps"])
            sel = np.asarray(ref["sel"], dtype=bool)
            meta = dict(ref["meta"])

            # Explicit params win; file metadata only fills the gaps.
            params = dict(getattr(curve, "params", {}) or {})
            for key in ("J_0_eV", "sigma_J_eV", "k_parallel", "sigma_sites"):
                if key not in params and key in meta:
                    params[key] = meta[key]

            model = str(getattr(curve, "model", "gaussian_wave"))
            # The analytical helpers take femtoseconds; the file stores picoseconds.
            t_fs = t_ps_ref[sel] * 1.0e3
            y = _nn_rmsd_analytical(t_fs, model, params)
            x = _to_plot_time(t_ps_ref[sel], ps) * x_scale

            fmt, line_kwargs = _curve_plot_args(curve, ps, f"Analytical RMSD ({model})")
            ax.plot(x, y, fmt, **line_kwargs)
            label = line_kwargs["label"]
            logger.info(
                f"Plotted analytical curve: {label} (model={model}, from_curve_index={from_curve_index})"
            )

        # --- axes decoration --------------------------------------------------
        ax.set_xlabel(ps.xlabel)
        ax.set_ylabel(ps.ylabel)
        if getattr(ps, "title", None):
            ax.set_title(ps.title)

        if getattr(ps, "xscale", None):
            ax.set_xscale(ps.xscale)
        if getattr(ps, "yscale", None):
            ax.set_yscale(ps.yscale)

        if getattr(ps, "xlim", None):
            ax.set_xlim(ps.xlim[0], ps.xlim[1])
        if getattr(ps, "ylim", None):
            ax.set_ylim(ps.ylim[0], ps.ylim[1])

        if getattr(ps, "grid", True):
            ax.grid(True, which="both", ls="--", alpha=0.5)
        if getattr(ps, "legend", True):
            ax.legend()
        if getattr(ps, "tight_layout", True):
            plt.tight_layout()

        # --- output -----------------------------------------------------------
        if getattr(ps, "save_plot", True):
            name = getattr(ps, "filename", "sqrt_msd.png")
            figpath = outdir / name
            fig.savefig(figpath, dpi=getattr(ps, "dpi", 300), bbox_inches="tight")
            logger.success(f"Saved plot → {figpath}")
        if getattr(ps, "show", False):
            plt.show()
    finally:
        # Release the figure even if loading or plotting raised.
        plt.close(fig)
if __name__ == "__main__":
main()