✨ clean up lot of stuff
This commit is contained in:
@@ -20,7 +20,6 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -29,6 +28,30 @@ import structlog
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def _run_sim(
|
||||
robot_path: Path,
|
||||
params: dict[str, float],
|
||||
recording: dict[str, np.ndarray],
|
||||
window_duration: float,
|
||||
sim_dt: float,
|
||||
substeps: int,
|
||||
motor_params: dict[str, float],
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Run windowed or open-loop rollout depending on window_duration."""
|
||||
from src.sysid.rollout import rollout, windowed_rollout
|
||||
|
||||
if window_duration > 0:
|
||||
return windowed_rollout(
|
||||
robot_path=robot_path, params=params, recording=recording,
|
||||
window_duration=window_duration, sim_dt=sim_dt,
|
||||
substeps=substeps, motor_params=motor_params,
|
||||
)
|
||||
return rollout(
|
||||
robot_path=robot_path, params=params, actions=recording["action"],
|
||||
substeps=substeps, motor_params=motor_params,
|
||||
)
|
||||
|
||||
|
||||
def visualize(
|
||||
robot_path: str | Path,
|
||||
recording_path: str | Path,
|
||||
@@ -39,31 +62,27 @@ def visualize(
|
||||
save_path: str | Path | None = None,
|
||||
show: bool = True,
|
||||
) -> None:
|
||||
"""Generate comparison plot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : robot asset directory
|
||||
recording_path : .npz file from capture
|
||||
result_path : sysid_result.json with best_params (optional)
|
||||
sim_dt / substeps : physics settings for rollout
|
||||
window_duration : shooting window length (s); 0 = open-loop
|
||||
save_path : if provided, save figure to this path (PNG, PDF, …)
|
||||
show : if True, display interactive matplotlib window
|
||||
"""
|
||||
"""Generate comparison plot."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from src.sysid.rollout import (
|
||||
LOCKED_MOTOR_PARAMS,
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
defaults_vector,
|
||||
params_to_dict,
|
||||
rollout,
|
||||
windowed_rollout,
|
||||
)
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
recording = dict(np.load(recording_path))
|
||||
|
||||
motor_params = LOCKED_MOTOR_PARAMS
|
||||
|
||||
sim_kwargs = dict(
|
||||
robot_path=robot_path, recording=recording,
|
||||
window_duration=window_duration, sim_dt=sim_dt,
|
||||
substeps=substeps, motor_params=motor_params,
|
||||
)
|
||||
|
||||
t = recording["time"]
|
||||
actions = recording["action"]
|
||||
|
||||
@@ -72,26 +91,15 @@ def visualize(
|
||||
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
|
||||
)
|
||||
log.info("simulating_default_params", windowed=window_duration > 0)
|
||||
if window_duration > 0:
|
||||
sim_default = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=default_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_default = rollout(
|
||||
robot_path=robot_path,
|
||||
params=default_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
sim_default = _run_sim(params=default_params, **sim_kwargs)
|
||||
|
||||
# ── Simulate with tuned parameters (if available) ────────────
|
||||
# Resolve result path (explicit or auto-detect).
|
||||
if result_path is None:
|
||||
auto = robot_path / "sysid_result.json"
|
||||
if auto.exists():
|
||||
result_path = auto
|
||||
|
||||
sim_tuned = None
|
||||
tuned_cost = None
|
||||
if result_path is not None:
|
||||
@@ -101,64 +109,21 @@ def visualize(
|
||||
tuned_params = result.get("best_params", {})
|
||||
tuned_cost = result.get("best_cost")
|
||||
log.info("simulating_tuned_params", cost=tuned_cost)
|
||||
if window_duration > 0:
|
||||
sim_tuned = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_tuned = rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
sim_tuned = _run_sim(params=tuned_params, **sim_kwargs)
|
||||
else:
|
||||
log.warning("result_file_not_found", path=str(result_path))
|
||||
else:
|
||||
# Auto-detect sysid_result.json in robot_path.
|
||||
auto_result = robot_path / "sysid_result.json"
|
||||
if auto_result.exists():
|
||||
result = json.loads(auto_result.read_text())
|
||||
tuned_params = result.get("best_params", {})
|
||||
tuned_cost = result.get("best_cost")
|
||||
log.info("auto_detected_tuned_params", cost=tuned_cost)
|
||||
if window_duration > 0:
|
||||
sim_tuned = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
else:
|
||||
sim_tuned = rollout(
|
||||
robot_path=robot_path,
|
||||
params=tuned_params,
|
||||
actions=actions,
|
||||
timesteps=t,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
|
||||
# ── Plot ─────────────────────────────────────────────────────
|
||||
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
|
||||
|
||||
channels = [
|
||||
("motor_angle", "Motor Angle (rad)", True),
|
||||
("motor_vel", "Motor Velocity (rad/s)", False),
|
||||
("pendulum_angle", "Pendulum Angle (rad)", True),
|
||||
("pendulum_vel", "Pendulum Velocity (rad/s)", False),
|
||||
("motor_angle", "Motor Angle (rad)"),
|
||||
("motor_vel", "Motor Velocity (rad/s)"),
|
||||
("pendulum_angle", "Pendulum Angle (rad)"),
|
||||
("pendulum_vel", "Pendulum Velocity (rad/s)"),
|
||||
]
|
||||
|
||||
for ax, (key, ylabel, is_angle) in zip(axes[:4], channels):
|
||||
for ax, (key, ylabel) in zip(axes[:4], channels):
|
||||
real = recording[key]
|
||||
|
||||
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
|
||||
@@ -207,6 +172,7 @@ def visualize(
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
window_duration=window_duration,
|
||||
motor_params=motor_params,
|
||||
)
|
||||
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
||||
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
||||
|
||||
Reference in New Issue
Block a user