from __future__ import annotations
from pathlib import Path
import numpy as np
import h5py
import matplotlib.pyplot as plt
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.core.hydra_config import HydraConfig
from loguru import logger
from mqed.utils.file_utils import _resolve_input_path
from mqed.utils.logging_utils import setup_loggers_hydra_aware
from mqed.utils.hydra_local import prepare_hydra_config_path
def _to_plot_time(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
unit = str(getattr(cfg_ps, "time_unit", "ps")).lower()
if unit == "ps":
return t_ps
if unit == "fs":
return t_ps * 1.0e3
if unit == "s":
return t_ps * 1.0e-12
raise ValueError(f"Unsupported plot_settings.time_unit='{unit}'. Use 'fs', 'ps', or 's'.")
def _msd_analytical_local(t_fs: np.ndarray, a: float, hbar_eV_fs: float, j_0_eV: float, sigma_j_eV: float) -> np.ndarray:
j_eff_sq = j_0_eV ** 2 + sigma_j_eV ** 2
return 2.0 * a ** 2 * j_eff_sq * t_fs ** 2 / (hbar_eV_fs ** 2)
def _x_square_analytical_gaussian(
t_fs: np.ndarray,
a: float,
hbar_eV_fs: float,
j_0_eV: float,
sigma_j_eV: float,
k_parallel: float,
sigma_sites: float,
) -> np.ndarray:
term1 = sigma_sites ** 2 / 2.0
term2 = (4.0 * j_0_eV ** 2 / hbar_eV_fs ** 2) * np.sin(k_parallel * a) ** 2
term3 = 2.0 * sigma_j_eV ** 2 / hbar_eV_fs ** 2
return a ** 2 * (term1 + (term2 + term3) * t_fs ** 2)
def _position_analytical_gaussian(t_fs: np.ndarray, a: float, hbar_eV_fs: float, j_0_eV: float, k_parallel: float) -> np.ndarray:
velocity_prefactor = (-2.0 * j_0_eV / hbar_eV_fs) * np.sin(k_parallel * a)
return a * velocity_prefactor * t_fs
[docs]
def _nn_msd_analytical(t_fs: np.ndarray, model: str, params: dict) -> np.ndarray:
"""Analytical MSD for nearest-neighbour chain models.
MSD = <(x-x0)^2> (second moment of displacement from initial site).
Note: this is NOT the variance <(x-x0)^2> - <x-x0>^2.
"""
a = float(params.get("a", 1.0))
hbar_eV_fs = float(params.get("hbar_eV_fs", 0.6582119514))
j_0_eV = float(params["J_0_eV"])
sigma_j_eV = float(params["sigma_J_eV"])
if model == "local_excitation":
return _msd_analytical_local(t_fs, a, hbar_eV_fs, j_0_eV, sigma_j_eV)
if model == "gaussian_wave":
k_parallel = float(params["k_parallel"])
sigma_sites = float(params["sigma_sites"])
# MSD = <(x-x0)^2> = x2 (the full second moment, not x2 - <x>^2)
x2 = _x_square_analytical_gaussian(t_fs, a, hbar_eV_fs, j_0_eV, sigma_j_eV, k_parallel, sigma_sites)
return x2
raise ValueError(f"Unsupported analytical model '{model}'. Use 'local_excitation' or 'gaussian_wave'.")
def _load_dx_and_time(h5_path: Path) -> tuple[np.ndarray, np.ndarray, dict]:
    """Load the time grid and an MSD curve from an HDF5 results file.

    Despite its historical name, this function returns the mean squared
    displacement (MSD), not Δx.

    The MSD is taken from the first source found, in this priority order:

    1. root dataset ``msd_nm2``;
    2. ``expectations/x2_mean`` when ``expectations/position_mean`` also exists;
    3. ``expectations/X_shift2`` when ``expectations/X_shift`` also exists;
    4. ``expectations/msd_mean``;
    5. ``expectations/x2_mean`` on its own;
    6. ``expectations/X_shift2`` on its own;
    7. root dataset ``dx_mean_nm``, squared (last resort).

    For sources 2, 3, 5 and 6 the stored second moment is used directly as
    MSD = <(x - x0)^2>. The companion first-moment dataset is only validated,
    never subtracted; subtracting it would give the variance instead. This
    assumes the stored second moment is already measured from the initial
    site; NOTE(review): confirm against the code that writes these files.

    Args:
        h5_path: path of the HDF5 file to read.

    Returns:
        ``(t_ps, msd, meta)``:
        - ``t_ps`` and ``msd`` are flattened 1-D arrays of equal shape.
        - ``meta`` holds ``"source"`` (which dataset(s) were used) plus any
          of a fixed set of root attributes present in the file.

    Raises:
        ValueError: if any of the following holds:
            - ``t_ps`` is missing;
            - a matching expectations entry is not a dataset;
            - no supported MSD source exists;
            - the shapes of ``msd`` and ``t_ps`` differ.
    """
    logger.info(f"Loading Δx data from {h5_path}")

    # Each entry is (names inside "expectations", error-message fragment),
    # listed in priority order. The first name is the dataset read as the MSD;
    # the others only have to exist as datasets.
    expectation_sources = (
        (("x2_mean", "position_mean"), "x2_mean or position_mean dataset"),
        (("X_shift2", "X_shift"), "X_shift2 or X_shift dataset"),
        (("msd_mean",), "msd_mean dataset"),
        (("x2_mean",), "x2_mean dataset"),
        (("X_shift2",), "X_shift2 dataset"),
    )

    meta: dict = {}
    with h5py.File(str(h5_path), "r") as f:
        t_ps_ds = f.get("t_ps")
        if not isinstance(t_ps_ds, h5py.Dataset):
            raise ValueError(f"{h5_path} has no 't_ps' dataset.")
        t_ps = np.asarray(t_ps_ds[...]).ravel()

        msd = None
        msd_ds = f.get("msd_nm2")
        ex_group = f.get("expectations")
        dx_mean_ds = f.get("dx_mean_nm")

        # 1) A direct MSD dataset at the root wins.
        if isinstance(msd_ds, h5py.Dataset):
            msd = np.asarray(msd_ds[...]).ravel()
            meta["source"] = "msd_nm2"
        # 2-6) Moments stored under the "expectations" group.
        elif isinstance(ex_group, h5py.Group):
            for names, what in expectation_sources:
                if not all(name in ex_group for name in names):
                    continue
                datasets = [ex_group.get(name) for name in names]
                if not all(isinstance(ds, h5py.Dataset) for ds in datasets):
                    raise ValueError(f"{h5_path} has invalid expectations/{what}.")
                msd = np.asarray(datasets[0][...]).ravel()
                meta["source"] = "expectations/" + ",".join(names)
                break

        # 7) Last resort: the file only stored sqrt-MSD, so square it.
        if msd is None:
            if not isinstance(dx_mean_ds, h5py.Dataset):
                raise ValueError(
                    f"{h5_path} does not contain 'msd_nm2', "
                    "'expectations/msd_mean', {'x2_mean','position_mean'}, {'X_shift2','X_shift'}, "
                    "'expectations/x2_mean', 'expectations/X_shift2', or 'dx_mean_nm'."
                )
            dx = np.asarray(dx_mean_ds[...]).ravel()
            msd = dx**2
            meta["source"] = "dx_mean_nm**2"

        # Carry over a few helpful root attributes if present. main() uses the
        # J_0_eV / sigma_J_eV / k_parallel / sigma_sites values to fill in
        # analytical-curve parameters.
        for key in (
            "method",
            "mode",
            "n_realizations",
            "sigma_phi_deg",
            "seed_base",
            "J_0_eV",
            "sigma_J_eV",
            "k_parallel",
            "sigma_sites",
            "eps_0_eV",
        ):
            if key in f.attrs:
                meta[key] = f.attrs[key]

    if msd.shape != t_ps.shape:
        raise ValueError(f"{h5_path} msd shape {msd.shape} and t_ps shape {t_ps.shape} mismatch.")
    return t_ps, msd, meta
[docs]
def _select_x(t_ps: np.ndarray, cfg_ps) -> np.ndarray:
"""Return boolean mask for x selection by index or by time value (ps)."""
if hasattr(cfg_ps, "x_index_range") and cfg_ps.x_index_range:
i0, i1 = int(cfg_ps.x_index_range[0]), int(cfg_ps.x_index_range[1])
sel = np.zeros_like(t_ps, dtype=bool)
sel[max(0, i0): min(len(t_ps), i1 + 1)] = True
return sel
if hasattr(cfg_ps, "x_range") and cfg_ps.x_range:
t_plot = _to_plot_time(t_ps, cfg_ps)
xmin, xmax = float(cfg_ps.x_range[0]), float(cfg_ps.x_range[1])
return (t_plot >= xmin) & (t_plot <= xmax)
if hasattr(cfg_ps, "x_range_ps") and cfg_ps.x_range_ps:
xmin, xmax = float(cfg_ps.x_range_ps[0]), float(cfg_ps.x_range_ps[1])
return (t_ps >= xmin) & (t_ps <= xmax)
if hasattr(cfg_ps, "x_range_fs") and cfg_ps.x_range_fs:
t_fs = t_ps * 1.0e3
xmin, xmax = float(cfg_ps.x_range_fs[0]), float(cfg_ps.x_range_fs[1])
return (t_fs >= xmin) & (t_fs <= xmax)
return np.ones_like(t_ps, dtype=bool)
# Config path passed to @hydra.main below (default config name "msd"). It is
# computed by the project helper prepare_hydra_config_path from the name
# "plots" and this file's location; see mqed.utils.hydra_local for the exact
# lookup rules.
HYDRA_CONFIG_PATH: str = prepare_hydra_config_path("plots", __file__)
@hydra.main(config_path=HYDRA_CONFIG_PATH, config_name="msd", version_base=None)
def main(cfg: DictConfig) -> None:
    """Plot MSD curves from HDF5 files, plus optional analytical references.

    Config sections read:
    - ``cfg.plot_settings``: figure and axis styling.
    - ``cfg.curves``: numerical curves, each resolved to an HDF5 path by
      ``_resolve_input_path``.
    - ``cfg.analytical_curves``: read only when
      ``plot_settings.enable_analytical_curves`` is truthy (default True).
      These are closed-form NN-chain MSD curves evaluated on the time grid of
      the numerical curve selected by ``from_curve_index``.

    When ``plot_settings.save_plot`` is truthy (default True), the figure is
    saved into the Hydra run output directory.
    """
    outdir = Path(HydraConfig.get().runtime.output_dir)
    setup_loggers_hydra_aware()
    ps = cfg.plot_settings

    font = getattr(ps, "font", None)
    # Set the global font family BEFORE creating the figure. matplotlib resolves
    # rcParams["font.family"] when each Text object is created, and the title,
    # axis-label and tick-label Text objects are created together with the
    # axes. Setting it afterwards would only affect later text such as the legend.
    if font and getattr(font, "family", None):
        plt.rcParams["font.family"] = str(font.family)

    figsize = (ps.figsize[0], ps.figsize[1]) if getattr(ps, "figsize", None) else (8, 6)
    fig, ax = plt.subplots(figsize=figsize)
    x_scale = getattr(ps, "x_scale_factor", 1.0)

    # ---- numerical curves -------------------------------------------------
    loaded_curves = []
    for curve in cfg.curves:
        path = _resolve_input_path(curve)
        t_ps, msd, meta = _load_dx_and_time(path)
        sel = _select_x(t_ps, ps)
        x = _to_plot_time(t_ps[sel], ps) * x_scale
        y = msd[sel]

        linestyle = getattr(curve, "linestyle", getattr(curve, "style", "-"))
        # Accept "none" in any case from the config; matplotlib expects "None".
        if isinstance(linestyle, str) and linestyle.lower() == "none":
            linestyle = "None"
        label = getattr(curve, "label", path.stem)
        color = getattr(curve, "color", None)
        plot_kwargs = {
            "lw": getattr(curve, "lw", ps.get("lw", 1.5)),
            "label": label,
            "color": color,
            "linestyle": linestyle,
        }
        # Optional line/marker properties are forwarded only when set, so that
        # matplotlib defaults apply otherwise. The marker edge colour falls
        # back to the line colour.
        optional = {
            "marker": getattr(curve, "marker", None),
            "markersize": getattr(curve, "markersize", None),
            "markerfacecolor": getattr(curve, "markerfacecolor", None),
            "markeredgecolor": getattr(curve, "markeredgecolor", color),
            "markeredgewidth": getattr(curve, "markeredgewidth", None),
            "markevery": getattr(curve, "markevery", None),
            "alpha": getattr(curve, "alpha", None),
            "zorder": getattr(curve, "zorder", None),
        }
        plot_kwargs.update({key: value for key, value in optional.items() if value is not None})
        ax.plot(x, y, **plot_kwargs)
        loaded_curves.append({"t_ps": t_ps, "sel": sel, "meta": meta})
        logger.info(f"Plotted {label} from {path.name} (source={meta.get('source','?')})")

    # ---- analytical curves ------------------------------------------------
    if bool(getattr(ps, "enable_analytical_curves", True)):
        # `or []` also covers an explicit `analytical_curves: null` in the config.
        for curve in getattr(cfg, "analytical_curves", None) or []:
            if not loaded_curves:
                raise ValueError("analytical_curves requires at least one numerical curve to define time grid.")
            from_curve_index = int(getattr(curve, "from_curve_index", 0))
            if from_curve_index < 0 or from_curve_index >= len(loaded_curves):
                raise ValueError(
                    f"analytical from_curve_index={from_curve_index} out of range [0, {len(loaded_curves)-1}]."
                )
            ref = loaded_curves[from_curve_index]
            t_ps_ref = np.asarray(ref["t_ps"])
            sel = np.asarray(ref["sel"], dtype=bool)
            meta = dict(ref["meta"])

            # Explicit params win; missing physical parameters are taken from
            # the reference file's attributes.
            params = dict(getattr(curve, "params", {}) or {})
            for key in ("J_0_eV", "sigma_J_eV", "k_parallel", "sigma_sites"):
                if key not in params and key in meta:
                    params[key] = meta[key]

            model = str(getattr(curve, "model", "gaussian_wave"))
            t_fs = t_ps_ref[sel] * 1.0e3
            y = _nn_msd_analytical(t_fs, model, params)
            x = _to_plot_time(t_ps_ref[sel], ps) * x_scale

            style = getattr(curve, "style", "-")
            lw = getattr(curve, "lw", ps.get("lw", 1.5))
            label = getattr(curve, "label", f"Analytical ({model})")
            color = getattr(curve, "color", None)
            ax.plot(x, y, style, lw=lw, label=label, color=color)
            logger.info(
                f"Plotted analytical curve: {label} (model={model}, from_curve_index={from_curve_index})"
            )

    # ---- fonts, labels and title ------------------------------------------
    if font:
        labelsize = int(getattr(font, "labelsize", 12))
        titlesize = int(getattr(font, "titlesize", 12))
        ticksize = int(getattr(font, "ticksize", 12))
        legendsize = int(getattr(font, "legendsize", 12))
        labelweight = str(getattr(font, "labelweight", "normal"))
        legendweight = str(getattr(font, "legendweight", "normal"))
    else:
        labelsize = titlesize = ticksize = legendsize = 12
        labelweight = legendweight = "normal"

    ax.set_xlabel(ps.xlabel, fontsize=labelsize, fontweight=labelweight)
    ax.set_ylabel(ps.ylabel, fontsize=labelsize, fontweight=labelweight)
    if getattr(ps, "title", None):
        ax.set_title(ps.title, fontsize=titlesize, fontweight=labelweight)
    ax.tick_params(axis="both", which="both", labelsize=ticksize)

    # ---- legend -------------------------------------------------------------
    if getattr(ps, "legend", True):
        leg = ax.legend(fontsize=legendsize)
        for txt in leg.get_texts():
            txt.set_fontweight(legendweight)

    # ---- scales and limits --------------------------------------------------
    if getattr(ps, "xscale", None):
        ax.set_xscale(ps.xscale)
    if getattr(ps, "yscale", None):
        ax.set_yscale(ps.yscale)
    if getattr(ps, "xlim", None):
        ax.set_xlim(ps.xlim[0], ps.xlim[1])
    if getattr(ps, "ylim", None):
        ax.set_ylim(ps.ylim[0], ps.ylim[1])

    ysc = getattr(ps, "y_sci", None)
    if ysc and getattr(ysc, "enabled", False):
        logger.info("Use scientific visualization on y axis.")
        if getattr(ysc, "style", "sci") == "sci":
            scilimits = getattr(ysc, "scilimits", (-2, 2))
            ax.ticklabel_format(
                axis="y",
                style="sci",
                scilimits=(int(scilimits[0]), int(scilimits[1])),
                useMathText=bool(getattr(ysc, "use_math_text", True)),
            )
            off = ax.yaxis.get_offset_text()
            off.set_fontsize(int(getattr(ysc, "offset_text_size", ticksize)))
        else:
            ax.ticklabel_format(axis="y", style="plain")

    # Tick-label weight (falls back to labelweight). It is applied only after
    # scales, limits and tick formatting are final, and to minor as well as
    # major ticks, so labels such as log-scale minor ticks are covered too.
    tickweight = str(getattr(font, "tickweight", labelweight)) if font else labelweight
    for ticklabel in ax.get_xticklabels(which="both") + ax.get_yticklabels(which="both"):
        ticklabel.set_fontweight(tickweight)

    if getattr(ps, "grid", True):
        ax.grid(True, which="both", ls="--", alpha=0.5)
    if getattr(ps, "tight_layout", True):
        plt.tight_layout()

    # ---- output -------------------------------------------------------------
    if getattr(ps, "save_plot", True):
        name = getattr(ps, "filename", "sqrt_msd.png")
        figpath = outdir / name
        fig.savefig(figpath, dpi=getattr(ps, "dpi", 300), bbox_inches="tight")
        logger.success(f"Saved plot → {figpath}")
    if getattr(ps, "show", False):
        plt.show()
    plt.close(fig)
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() builds
    # `cfg` from the config files and command-line overrides.
    main()