clean up lot of stuff

This commit is contained in:
2026-03-22 15:49:13 +01:00
parent d3ed1c25ad
commit ca0e7b8b03
37 changed files with 3613 additions and 1223 deletions

View File

@@ -20,7 +20,6 @@ from __future__ import annotations
import argparse
import json
import math
from pathlib import Path
import numpy as np
@@ -29,6 +28,30 @@ import structlog
log = structlog.get_logger()
def _run_sim(
robot_path: Path,
params: dict[str, float],
recording: dict[str, np.ndarray],
window_duration: float,
sim_dt: float,
substeps: int,
motor_params: dict[str, float],
) -> dict[str, np.ndarray]:
"""Run windowed or open-loop rollout depending on window_duration."""
from src.sysid.rollout import rollout, windowed_rollout
if window_duration > 0:
return windowed_rollout(
robot_path=robot_path, params=params, recording=recording,
window_duration=window_duration, sim_dt=sim_dt,
substeps=substeps, motor_params=motor_params,
)
return rollout(
robot_path=robot_path, params=params, actions=recording["action"],
substeps=substeps, motor_params=motor_params,
)
def visualize(
robot_path: str | Path,
recording_path: str | Path,
@@ -39,31 +62,27 @@ def visualize(
save_path: str | Path | None = None,
show: bool = True,
) -> None:
"""Generate comparison plot.
Parameters
----------
robot_path : robot asset directory
recording_path : .npz file from capture
result_path : sysid_result.json with best_params (optional)
sim_dt / substeps : physics settings for rollout
window_duration : shooting window length (s); 0 = open-loop
save_path : if provided, save figure to this path (PNG, PDF, …)
show : if True, display interactive matplotlib window
"""
"""Generate comparison plot."""
import matplotlib.pyplot as plt
from src.sysid.rollout import (
LOCKED_MOTOR_PARAMS,
ROTARY_CARTPOLE_PARAMS,
defaults_vector,
params_to_dict,
rollout,
windowed_rollout,
)
robot_path = Path(robot_path).resolve()
recording = dict(np.load(recording_path))
motor_params = LOCKED_MOTOR_PARAMS
sim_kwargs = dict(
robot_path=robot_path, recording=recording,
window_duration=window_duration, sim_dt=sim_dt,
substeps=substeps, motor_params=motor_params,
)
t = recording["time"]
actions = recording["action"]
@@ -72,26 +91,15 @@ def visualize(
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
)
log.info("simulating_default_params", windowed=window_duration > 0)
if window_duration > 0:
sim_default = windowed_rollout(
robot_path=robot_path,
params=default_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_default = rollout(
robot_path=robot_path,
params=default_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
sim_default = _run_sim(params=default_params, **sim_kwargs)
# ── Simulate with tuned parameters (if available) ────────────
# Resolve result path (explicit or auto-detect).
if result_path is None:
auto = robot_path / "sysid_result.json"
if auto.exists():
result_path = auto
sim_tuned = None
tuned_cost = None
if result_path is not None:
@@ -101,64 +109,21 @@ def visualize(
tuned_params = result.get("best_params", {})
tuned_cost = result.get("best_cost")
log.info("simulating_tuned_params", cost=tuned_cost)
if window_duration > 0:
sim_tuned = windowed_rollout(
robot_path=robot_path,
params=tuned_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_tuned = rollout(
robot_path=robot_path,
params=tuned_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
sim_tuned = _run_sim(params=tuned_params, **sim_kwargs)
else:
log.warning("result_file_not_found", path=str(result_path))
else:
# Auto-detect sysid_result.json in robot_path.
auto_result = robot_path / "sysid_result.json"
if auto_result.exists():
result = json.loads(auto_result.read_text())
tuned_params = result.get("best_params", {})
tuned_cost = result.get("best_cost")
log.info("auto_detected_tuned_params", cost=tuned_cost)
if window_duration > 0:
sim_tuned = windowed_rollout(
robot_path=robot_path,
params=tuned_params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
)
else:
sim_tuned = rollout(
robot_path=robot_path,
params=tuned_params,
actions=actions,
timesteps=t,
sim_dt=sim_dt,
substeps=substeps,
)
# ── Plot ─────────────────────────────────────────────────────
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
channels = [
("motor_angle", "Motor Angle (rad)", True),
("motor_vel", "Motor Velocity (rad/s)", False),
("pendulum_angle", "Pendulum Angle (rad)", True),
("pendulum_vel", "Pendulum Velocity (rad/s)", False),
("motor_angle", "Motor Angle (rad)"),
("motor_vel", "Motor Velocity (rad/s)"),
("pendulum_angle", "Pendulum Angle (rad)"),
("pendulum_vel", "Pendulum Velocity (rad/s)"),
]
for ax, (key, ylabel, is_angle) in zip(axes[:4], channels):
for ax, (key, ylabel) in zip(axes[:4], channels):
real = recording[key]
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
@@ -207,6 +172,7 @@ def visualize(
sim_dt=sim_dt,
substeps=substeps,
window_duration=window_duration,
motor_params=motor_params,
)
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0