♻️ crazy refactor

This commit is contained in:
2026-03-11 22:52:01 +01:00
parent 35223b3560
commit 4115447022
34 changed files with 4255 additions and 102 deletions

287
src/sysid/visualize.py Normal file
View File

@@ -0,0 +1,287 @@
"""Visualise system identification results — real vs simulated trajectories.
Loads a recording and runs simulation with both the original and tuned
parameters, then plots a 4-panel comparison (motor angle, motor vel,
pendulum angle, pendulum vel) over time.
Usage:
python -m src.sysid.visualize \
--robot-path assets/rotary_cartpole \
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
# Also compare with tuned parameters:
python -m src.sysid.visualize \
--robot-path assets/rotary_cartpole \
--recording <file>.npz \
--result assets/rotary_cartpole/sysid_result.json
"""
from __future__ import annotations
import argparse
import json
import math
from pathlib import Path
import numpy as np
import structlog
log = structlog.get_logger()
def visualize(
robot_path: str | Path,
recording_path: str | Path,
result_path: str | Path | None = None,
sim_dt: float = 0.002,
substeps: int = 10,
window_duration: float = 0.5,
save_path: str | Path | None = None,
show: bool = True,
) -> None:
"""Generate comparison plot.
Parameters
----------
robot_path : robot asset directory
recording_path : .npz file from capture
result_path : sysid_result.json with best_params (optional)
sim_dt / substeps : physics settings for rollout
window_duration : shooting window length (s); 0 = open-loop
save_path : if provided, save figure to this path (PNG, PDF, …)
show : if True, display interactive matplotlib window
"""
import matplotlib.pyplot as plt
from src.sysid.rollout import (
ROTARY_CARTPOLE_PARAMS,
defaults_vector,
params_to_dict,
rollout,
windowed_rollout,
)
robot_path = Path(robot_path).resolve()
recording = dict(np.load(recording_path))
t = recording["time"]
actions = recording["action"]
# ── Simulate with default parameters ─────────────────────────
default_params = params_to_dict(
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
)
log.info("simulating_default_params", windowed=window_duration > 0)
if window_duration > 0:
sim_default = windowed_rollout(
robot_path=robot_path,
params=default_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_default = rollout(
robot_path=robot_path,
params=default_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
# ── Simulate with tuned parameters (if available) ────────────
sim_tuned = None
tuned_cost = None
if result_path is not None:
result_path = Path(result_path)
if result_path.exists():
result = json.loads(result_path.read_text())
tuned_params = result.get("best_params", {})
tuned_cost = result.get("best_cost")
log.info("simulating_tuned_params", cost=tuned_cost)
if window_duration > 0:
sim_tuned = windowed_rollout(
robot_path=robot_path,
params=tuned_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_tuned = rollout(
robot_path=robot_path,
params=tuned_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
else:
log.warning("result_file_not_found", path=str(result_path))
else:
# Auto-detect sysid_result.json in robot_path.
auto_result = robot_path / "sysid_result.json"
if auto_result.exists():
result = json.loads(auto_result.read_text())
tuned_params = result.get("best_params", {})
tuned_cost = result.get("best_cost")
log.info("auto_detected_tuned_params", cost=tuned_cost)
if window_duration > 0:
sim_tuned = windowed_rollout(
robot_path=robot_path,
params=tuned_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_tuned = rollout(
robot_path=robot_path,
params=tuned_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
# ── Plot ─────────────────────────────────────────────────────
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
channels = [
("motor_angle", "Motor Angle (rad)", True),
("motor_vel", "Motor Velocity (rad/s)", False),
("pendulum_angle", "Pendulum Angle (rad)", True),
("pendulum_vel", "Pendulum Velocity (rad/s)", False),
]
for ax, (key, ylabel, is_angle) in zip(axes[:4], channels):
real = recording[key]
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
ax.plot(
t,
sim_default[key],
"--",
color="#d62728",
linewidth=1.0,
alpha=0.7,
label="Sim (original)",
)
if sim_tuned is not None:
ax.plot(
t,
sim_tuned[key],
"--",
color="#2ca02c",
linewidth=1.0,
alpha=0.7,
label="Sim (tuned)",
)
ax.set_ylabel(ylabel)
ax.legend(loc="upper right", fontsize=8)
ax.grid(True, alpha=0.3)
# Action plot (bottom panel).
axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
axes[4].set_ylabel("Action (norm)")
axes[4].set_xlabel("Time (s)")
axes[4].grid(True, alpha=0.3)
axes[4].set_ylim(-1.1, 1.1)
# Title with cost info.
title = "System Identification — Real vs Simulated Trajectories"
if tuned_cost is not None:
# Compute original cost for comparison.
from src.sysid.optimize import cost_function
orig_cost = cost_function(
defaults_vector(ROTARY_CARTPOLE_PARAMS),
recording,
robot_path,
ROTARY_CARTPOLE_PARAMS,
sim_dt=sim_dt,
substeps=substeps,
window_duration=window_duration,
)
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
title += f" ({improvement:+.1f}%)"
fig.suptitle(title, fontsize=12)
plt.tight_layout()
if save_path:
save_path = Path(save_path)
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
log.info("figure_saved", path=str(save_path))
if show:
plt.show()
else:
plt.close(fig)
# ── CLI entry point ──────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="Visualise system identification results."
)
parser.add_argument(
"--robot-path",
type=str,
default="assets/rotary_cartpole",
)
parser.add_argument(
"--recording",
type=str,
required=True,
help="Path to .npz recording file",
)
parser.add_argument(
"--result",
type=str,
default=None,
help="Path to sysid_result.json (auto-detected if omitted)",
)
parser.add_argument("--sim-dt", type=float, default=0.002)
parser.add_argument("--substeps", type=int, default=10)
parser.add_argument(
"--window-duration",
type=float,
default=0.5,
help="Shooting window length in seconds (0 = open-loop)",
)
parser.add_argument(
"--save",
type=str,
default=None,
help="Save figure to this path (PNG, PDF, …)",
)
parser.add_argument(
"--no-show",
action="store_true",
help="Don't show interactive window (useful for CI)",
)
args = parser.parse_args()
visualize(
robot_path=args.robot_path,
recording_path=args.recording,
result_path=args.result,
sim_dt=args.sim_dt,
substeps=args.substeps,
window_duration=args.window_duration,
save_path=args.save,
show=not args.no_show,
)
if __name__ == "__main__":
main()