✨ clean up lot of stuff
This commit is contained in:
379
scripts/eval.py
Normal file
379
scripts/eval.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Evaluate a trained policy on real hardware (or in simulation).
|
||||
|
||||
Loads a checkpoint and runs the policy in a closed loop. For real
|
||||
hardware the serial runner talks to the ESP32; for sim it uses the
|
||||
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
|
||||
in both modes.
|
||||
|
||||
Usage (real hardware):
|
||||
mjpython scripts/eval.py env=rotary_cartpole runner=serial \
|
||||
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
|
||||
|
||||
Usage (simulation):
|
||||
mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
|
||||
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
|
||||
|
||||
Controls:
|
||||
Space — pause / resume policy (motor stops while paused)
|
||||
R — reset environment
|
||||
Esc — quit
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.models.mlp import SharedMLP
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
|
||||
_reset_flag = [False]
|
||||
_paused = [False]
|
||||
_quit_flag = [False]
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
|
||||
"""Called by MuJoCo viewer on key press."""
|
||||
if keycode == 32: # GLFW_KEY_SPACE
|
||||
_paused[0] = not _paused[0]
|
||||
elif keycode == 82: # GLFW_KEY_R
|
||||
_reset_flag[0] = True
|
||||
elif keycode == 256: # GLFW_KEY_ESCAPE
|
||||
_quit_flag[0] = True
|
||||
|
||||
|
||||
# ── checkpoint loading ───────────────────────────────────────────────
|
||||
|
||||
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
|
||||
"""Infer hidden layer sizes from a SharedMLP state dict."""
|
||||
sizes = []
|
||||
i = 0
|
||||
while f"net.{i}.weight" in state_dict:
|
||||
sizes.append(state_dict[f"net.{i}.weight"].shape[0])
|
||||
i += 2 # skip activation layers (ELU)
|
||||
return tuple(sizes)
|
||||
|
||||
|
||||
def load_policy(
    checkpoint_path: str,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device = torch.device("cpu"),
) -> tuple[SharedMLP, RunningStandardScaler]:
    """Load a trained SharedMLP + observation normalizer from a checkpoint.

    Args:
        checkpoint_path: Path to the saved agent checkpoint (.pt).
        observation_space: Environment observation space (model input size).
        action_space: Environment action space (model output size).
        device: Torch device to place model and normalizer on.

    Returns:
        (model, state_preprocessor) ready for inference.
    """
    # weights_only=False: the checkpoint stores full objects, not just
    # tensors — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    policy_weights = checkpoint["policy"]

    # The architecture is not stored explicitly — recover the hidden
    # layer widths from the weight shapes so the model can be rebuilt.
    hidden_sizes = _infer_hidden_sizes(policy_weights)

    model = SharedMLP(
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        hidden_sizes=hidden_sizes,
    )
    model.load_state_dict(policy_weights)
    model.eval()

    # Rebuild the observation normalizer from the saved running statistics.
    scaler_state = checkpoint["state_preprocessor"]
    state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
    state_preprocessor.running_mean = scaler_state["running_mean"].to(device)
    state_preprocessor.running_variance = scaler_state["running_variance"].to(device)
    state_preprocessor.current_count = scaler_state["current_count"]
    # Freeze the normalizer — stats must not drift during evaluation.
    state_preprocessor.training = False

    mean = state_preprocessor.running_mean
    std = state_preprocessor.running_variance.sqrt()
    logger.info(
        "checkpoint_loaded",
        path=checkpoint_path,
        hidden_sizes=hidden_sizes,
        obs_mean=[round(x, 3) for x in mean.tolist()],
        obs_std=[round(x, 3) for x in std.tolist()],
    )
    return model, state_preprocessor
|
||||
|
||||
|
||||
# ── action arrow overlay ─────────────────────────────────────────────
|
||||
|
||||
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow showing applied torque direction.

    Adds a one-frame overlay arrow to the viewer's user scene, anchored at
    the body of the first actuator's joint and aligned with the joint axis.
    Green for a positive action, red for a negative one; length scales
    with |action_val|.

    Args:
        viewer: MuJoCo passive-viewer handle (provides ``user_scn``).
        model: mujoco.MjModel (must have at least one actuator to draw).
        data: mujoco.MjData holding current world kinematics (xpos/xmat).
        action_val: Commanded scalar action; presumably in [-1, 1] since
            callers clamp it — TODO confirm.
    """
    # Skip the overlay in the dead zone or when nothing is actuated.
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    # Anchor at the body of the first actuator's joint, lifted slightly
    # in z so the arrow doesn't z-fight with the body geometry.
    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02

    # Joint axis rotated into the world frame; the arrow points along
    # +/- axis depending on the action sign, length ∝ |action|.
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Build an orthonormal frame whose z-axis is the arrow direction
    # (mjGEOM_ARROW is drawn along local +z). The 1e-8 epsilons guard
    # against a zero-length direction; `up` switches to +y when the
    # arrow is near-vertical to keep the cross product well-conditioned.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    # Flattened row-major 3x3 rotation with columns (x, y, z).
    mat = np.column_stack([x, y, z]).flatten()

    # Green for positive actions, red for negative (alpha 0.8).
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    # NOTE(review): assumes a free slot in user_scn — there is no bounds
    # check against user_scn.maxgeom; callers reset ngeom each frame, so
    # this holds today, but confirm if more overlays are added.
    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    viewer.user_scn.ngeom += 1
|
||||
|
||||
|
||||
# ── main loops ───────────────────────────────────────────────────────
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: validate the checkpoint and dispatch to sim or serial eval."""
    runtime_choices = HydraConfig.get().runtime.choices
    env_name = runtime_choices.get("env", "cartpole")
    runner_name = runtime_choices.get("runner", "mujoco_single")

    ckpt = cfg.get("checkpoint", None)
    if ckpt is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    # Hydra changes the working directory at launch — resolve relative
    # checkpoint paths against the original cwd instead.
    resolved = Path(hydra.utils.get_original_cwd()) / ckpt
    checkpoint_path = str(resolved)
    if not resolved.exists():
        logger.error("checkpoint_not_found", path=checkpoint_path)
        sys.exit(1)

    evaluate = _eval_serial if runner_name == "serial" else _eval_sim
    evaluate(cfg, env_name, checkpoint_path)
|
||||
|
||||
|
||||
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy in MuJoCo simulation with viewer.

    Builds a single-environment MuJoCoRunner, loads the checkpointed
    policy, and steps it closed-loop while mirroring the state in a
    passive viewer. Space pauses, R resets, Esc quits (see module
    docstring).

    Args:
        cfg: Composed Hydra config; ``cfg.runner`` holds runner settings.
        env_name: Registry name of the environment to build.
        checkpoint_path: Absolute path to the agent checkpoint (.pt).
    """
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1  # evaluation always runs a single env
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    device = runner.device
    model, preprocessor = load_policy(
        checkpoint_path, runner.observation_space, runner.action_space, device
    )

    # Internal runner handles, used only for visualization and logging.
    mj_model = runner._model
    mj_data = runner._data[0]
    # One control step spans `substeps` physics steps of length `dt`.
    dt_ctrl = runner.config.dt * runner.config.substeps

    with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
        obs, _ = runner.reset()
        step = 0
        episode = 0
        episode_reward = 0.0

        logger.info(
            "eval_started",
            env=env_name,
            mode="simulation",
            checkpoint=Path(checkpoint_path).name,
            controls="Space=pause, R=reset, Esc=quit",
        )

        while viewer.is_running() and not _quit_flag[0]:
            loop_start = time.monotonic()

            if _reset_flag[0]:
                _reset_flag[0] = False
                obs, _ = runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0
                logger.info("reset", episode=episode)

            if _paused[0]:
                viewer.sync()
                time.sleep(0.05)
                continue

            # Policy inference — no gradients needed at eval time.
            with torch.no_grad():
                normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                action = model.act({"states": normalized_obs}, role="policy")[0]
                action = action.clamp(-1.0, 1.0)

            obs, reward, terminated, truncated, info = runner.step(action)
            episode_reward += reward.item()
            step += 1

            # Sync viewer: recompute derived kinematics, then redraw overlays.
            mujoco.mj_forward(mj_model, mj_data)
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
            viewer.sync()

            if step % 50 == 0:
                # NOTE(review): indexing qpos by joint index assumes 1-DoF
                # joints (hinge/slide) — holds for this robot; confirm if
                # free/ball joints are ever added.
                joints = {mj_model.jnt(i).name: round(math.degrees(mj_data.qpos[i]), 1)
                          for i in range(mj_model.njnt)}
                logger.debug(
                    "step", n=step, reward=round(reward.item(), 3),
                    action=round(action[0, 0].item(), 2),
                    ep_reward=round(episode_reward, 1), **joints,
                )

            if terminated.any() or truncated.any():
                logger.info(
                    "episode_done", episode=episode, steps=step,
                    total_reward=round(episode_reward, 2),
                    reason="terminated" if terminated.any() else "truncated",
                )
                obs, _ = runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0

            # Fix: sleep only for the *remainder* of the control period.
            # Previously this slept a full dt_ctrl on top of inference and
            # render time, so the "real-time" loop drifted slower than
            # real time.
            elapsed = time.monotonic() - loop_start
            if elapsed < dt_ctrl:
                time.sleep(dt_ctrl - elapsed)

    runner.close()
|
||||
|
||||
|
||||
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy on real hardware via serial, with digital-twin viewer.

    Builds a SerialRunner (talking to the ESP32 over the configured port),
    loads the checkpointed policy, and steps it closed-loop. A MuJoCo
    viewer mirrors the real sensor state as a digital twin. The motor is
    commanded to stop ("M0" — presumably zero-output; confirm against the
    firmware protocol) whenever the loop pauses or a safety stop triggers.

    Args:
        cfg: Composed Hydra config; ``cfg.runner`` holds serial settings.
        env_name: Registry name of the environment to build.
        checkpoint_path: Absolute path to the agent checkpoint (.pt).
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))

    device = serial_runner.device
    model, preprocessor = load_policy(
        checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device
    )

    # Set up digital-twin MuJoCo model for visualization.
    serial_runner._ensure_viz_model()
    mj_model = serial_runner._viz_model
    mj_data = serial_runner._viz_data

    with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
        obs, _ = serial_runner.reset()
        step = 0
        episode = 0
        episode_reward = 0.0

        logger.info(
            "eval_started",
            env=env_name,
            mode="real hardware (serial)",
            port=serial_runner.config.port,
            checkpoint=Path(checkpoint_path).name,
            controls="Space=pause, R=reset, Esc=quit",
        )

        while viewer.is_running() and not _quit_flag[0]:
            # Manual reset (R key): stop the motor, physically re-center,
            # wait for the pendulum to settle, then reset the episode.
            if _reset_flag[0]:
                _reset_flag[0] = False
                serial_runner._send("M0")
                serial_runner._drive_to_center()
                serial_runner._wait_for_pendulum_still()
                obs, _ = serial_runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0
                logger.info("reset", episode=episode)

            if _paused[0]:
                serial_runner._send("M0")  # safety: stop motor while paused
                serial_runner._sync_viz()
                viewer.sync()
                time.sleep(0.05)
                continue

            # Policy inference — no gradients needed at eval time.
            with torch.no_grad():
                normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                action = model.act({"states": normalized_obs}, role="policy")[0]
                action = action.clamp(-1.0, 1.0)

            obs, reward, terminated, truncated, info = serial_runner.step(action)
            episode_reward += reward.item()
            step += 1

            # Sync digital twin with real sensor data.
            serial_runner._sync_viz()
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
            viewer.sync()

            # Periodic telemetry, including raw sensor readout.
            if step % 25 == 0:
                state = serial_runner._read_state()
                logger.debug(
                    "step", n=step, reward=round(reward.item(), 3),
                    action=round(action[0, 0].item(), 2),
                    ep_reward=round(episode_reward, 1),
                    motor_enc=state["encoder_count"],
                    pend_deg=round(state["pendulum_angle"], 1),
                )

            # Check for safety / disconnection.
            # A firmware reboot or motor travel-limit breach aborts the
            # whole eval loop (not just the episode) after stopping the motor.
            if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                logger.error(
                    "safety_stop",
                    reboot=info.get("reboot_detected", False),
                    motor_limit=info.get("motor_limit_exceeded", False),
                )
                serial_runner._send("M0")
                break

            if terminated.any() or truncated.any():
                logger.info(
                    "episode_done", episode=episode, steps=step,
                    total_reward=round(episode_reward, 2),
                    reason="terminated" if terminated.any() else "truncated",
                )
                # Auto-reset for next episode.
                obs, _ = serial_runner.reset()
                step = 0
                episode += 1
                episode_reward = 0.0

            # Real-time pacing is handled by serial_runner.step() (dt sleep).

    serial_runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Hydra's decorator on main() handles CLI parsing and config composition.
    main()
|
||||
64
scripts/motor_sysid.py
Normal file
64
scripts/motor_sysid.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Unified CLI for motor-only system identification.
|
||||
|
||||
Usage:
|
||||
python scripts/motor_sysid.py capture --duration 20
|
||||
python scripts/motor_sysid.py optimize --recording assets/motor/recordings/<file>.npz
|
||||
python scripts/motor_sysid.py visualize --recording assets/motor/recordings/<file>.npz
|
||||
python scripts/motor_sysid.py export --result assets/motor/motor_sysid_result.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
|
||||
def main() -> None:
    """Dispatch to a motor-sysid subcommand (capture/optimize/visualize/export).

    Prints usage and exits 0 when called with no command or -h/--help;
    exits 1 on an unknown command. Otherwise rewrites ``sys.argv`` for the
    subcommand and invokes its ``main``.
    """
    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print(
            "Motor System Identification\n"
            "===========================\n"
            "Usage: python scripts/motor_sysid.py <command> [options]\n"
            "\n"
            "Commands:\n"
            "  capture    Record motor trajectory under PRBS excitation\n"
            "  optimize   Run CMA-ES to fit motor parameters\n"
            "  visualize  Plot real vs simulated motor response\n"
            "  export     Write tuned MJCF + robot.yaml files\n"
            "\n"
            "Workflow:\n"
            "  1. Flash sysid firmware to ESP32 (motor-only, no limits)\n"
            "  2. python scripts/motor_sysid.py capture --duration 20\n"
            "  3. python scripts/motor_sysid.py optimize --recording <file>.npz\n"
            "  4. python scripts/motor_sysid.py visualize --recording <file>.npz\n"
            "\n"
            "Run '<command> --help' for command-specific options."
        )
        sys.exit(0)

    command = sys.argv[1]
    # Rewrite argv so the subcommand's own argument parser sees a sensible
    # program name and only its own options.
    sys.argv = [f"motor_sysid {command}"] + sys.argv[2:]

    # Imports are deferred so an unknown command fails fast without
    # paying each subcommand's import cost.
    if command == "capture":
        from src.sysid.motor.capture import main as cmd_main
    elif command == "optimize":
        from src.sysid.motor.optimize import main as cmd_main
    elif command == "visualize":
        from src.sysid.motor.visualize import main as cmd_main
    elif command == "export":
        from src.sysid.motor.export import main as cmd_main
    else:
        # Fix: diagnostics belong on stderr so stdout stays clean for
        # scripted use (previously printed to stdout).
        print(f"Unknown command: {command}", file=sys.stderr)
        print("Available commands: capture, optimize, visualize, export", file=sys.stderr)
        sys.exit(1)

    cmd_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Dispatch to the requested sysid subcommand.
    main()
|
||||
Reference in New Issue
Block a user