♻️ crazy refactor
This commit is contained in:
287
src/sysid/visualize.py
Normal file
287
src/sysid/visualize.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Visualise system identification results — real vs simulated trajectories.
|
||||
|
||||
Loads a recording and runs simulation with both the original and tuned
|
||||
parameters, then plots a 4-panel comparison (motor angle, motor vel,
|
||||
pendulum angle, pendulum vel) over time.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||
|
||||
# Also compare with tuned parameters:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording <file>.npz \
|
||||
--result assets/rotary_cartpole/sysid_result.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def visualize(
|
||||
robot_path: str | Path,
|
||||
recording_path: str | Path,
|
||||
result_path: str | Path | None = None,
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
window_duration: float = 0.5,
|
||||
save_path: str | Path | None = None,
|
||||
show: bool = True,
|
||||
) -> None:
|
||||
"""Generate comparison plot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : robot asset directory
|
||||
recording_path : .npz file from capture
|
||||
result_path : sysid_result.json with best_params (optional)
|
||||
sim_dt / substeps : physics settings for rollout
|
||||
window_duration : shooting window length (s); 0 = open-loop
|
||||
save_path : if provided, save figure to this path (PNG, PDF, …)
|
||||
show : if True, display interactive matplotlib window
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from src.sysid.rollout import (
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
defaults_vector,
|
||||
params_to_dict,
|
||||
rollout,
|
||||
windowed_rollout,
|
||||
)
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
recording = dict(np.load(recording_path))
|
||||
|
||||
t = recording["time"]
|
||||
actions = recording["action"]
|
||||
|
||||
# ── Simulate with default parameters ─────────────────────────
|
||||
default_params = params_to_dict(
|
||||
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
|
||||
)
|
||||
log.info("simulating_default_params", windowed=window_duration > 0)
|
||||
if window_duration > 0:
|
||||
sim_default = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=default_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_default = rollout(
|
||||
robot_path=robot_path,
|
||||
params=default_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
|
||||
# ── Simulate with tuned parameters (if available) ────────────
|
||||
sim_tuned = None
|
||||
tuned_cost = None
|
||||
if result_path is not None:
|
||||
result_path = Path(result_path)
|
||||
if result_path.exists():
|
||||
result = json.loads(result_path.read_text())
|
||||
tuned_params = result.get("best_params", {})
|
||||
tuned_cost = result.get("best_cost")
|
||||
log.info("simulating_tuned_params", cost=tuned_cost)
|
||||
if window_duration > 0:
|
||||
sim_tuned = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_tuned = rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
log.warning("result_file_not_found", path=str(result_path))
|
||||
else:
|
||||
# Auto-detect sysid_result.json in robot_path.
|
||||
auto_result = robot_path / "sysid_result.json"
|
||||
if auto_result.exists():
|
||||
result = json.loads(auto_result.read_text())
|
||||
tuned_params = result.get("best_params", {})
|
||||
tuned_cost = result.get("best_cost")
|
||||
log.info("auto_detected_tuned_params", cost=tuned_cost)
|
||||
if window_duration > 0:
|
||||
sim_tuned = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_tuned = rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
|
||||
# ── Plot ─────────────────────────────────────────────────────
|
||||
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
|
||||
|
||||
channels = [
|
||||
("motor_angle", "Motor Angle (rad)", True),
|
||||
("motor_vel", "Motor Velocity (rad/s)", False),
|
||||
("pendulum_angle", "Pendulum Angle (rad)", True),
|
||||
("pendulum_vel", "Pendulum Velocity (rad/s)", False),
|
||||
]
|
||||
|
||||
for ax, (key, ylabel, is_angle) in zip(axes[:4], channels):
|
||||
real = recording[key]
|
||||
|
||||
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
|
||||
ax.plot(
|
||||
t,
|
||||
sim_default[key],
|
||||
"--",
|
||||
color="#d62728",
|
||||
linewidth=1.0,
|
||||
alpha=0.7,
|
||||
label="Sim (original)",
|
||||
)
|
||||
if sim_tuned is not None:
|
||||
ax.plot(
|
||||
t,
|
||||
sim_tuned[key],
|
||||
"--",
|
||||
color="#2ca02c",
|
||||
linewidth=1.0,
|
||||
alpha=0.7,
|
||||
label="Sim (tuned)",
|
||||
)
|
||||
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.legend(loc="upper right", fontsize=8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Action plot (bottom panel).
|
||||
axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
|
||||
axes[4].set_ylabel("Action (norm)")
|
||||
axes[4].set_xlabel("Time (s)")
|
||||
axes[4].grid(True, alpha=0.3)
|
||||
axes[4].set_ylim(-1.1, 1.1)
|
||||
|
||||
# Title with cost info.
|
||||
title = "System Identification — Real vs Simulated Trajectories"
|
||||
if tuned_cost is not None:
|
||||
# Compute original cost for comparison.
|
||||
from src.sysid.optimize import cost_function
|
||||
|
||||
orig_cost = cost_function(
|
||||
defaults_vector(ROTARY_CARTPOLE_PARAMS),
|
||||
recording,
|
||||
robot_path,
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
window_duration=window_duration,
|
||||
)
|
||||
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
||||
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
||||
title += f" ({improvement:+.1f}%)"
|
||||
|
||||
fig.suptitle(title, fontsize=12)
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
save_path = Path(save_path)
|
||||
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
|
||||
log.info("figure_saved", path=str(save_path))
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Visualise system identification results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="assets/rotary_cartpole",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recording",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to .npz recording file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to sysid_result.json (auto-detected if omitted)",
|
||||
)
|
||||
parser.add_argument("--sim-dt", type=float, default=0.002)
|
||||
parser.add_argument("--substeps", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--window-duration",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Shooting window length in seconds (0 = open-loop)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Save figure to this path (PNG, PDF, …)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-show",
|
||||
action="store_true",
|
||||
help="Don't show interactive window (useful for CI)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
visualize(
|
||||
robot_path=args.robot_path,
|
||||
recording_path=args.recording,
|
||||
result_path=args.result,
|
||||
sim_dt=args.sim_dt,
|
||||
substeps=args.substeps,
|
||||
window_duration=args.window_duration,
|
||||
save_path=args.save,
|
||||
show=not args.no_show,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user