Close the sim2real gap for the Furuta pendulum (swings up but can't balance on hardware). Root causes were (a) no domain randomization, so the policy overfit one deterministic sim instance, and (b) reward design flaws that produced degenerate policies. Domain randomization (runner-level, backend-agnostic): - BaseRunner: domain_rand config; per-env action-delay buffer (latency), Gaussian qpos/qvel sensor noise, per-env dynamics-scale sampling (friction/damping/torque), resampled per episode. Sensor noise is sampled every step. - privileged_obs/privileged_dim expose normalized DR factors (mu) for RMA. - step() now uses clean state for reward/termination, noisy state for the observation the policy sees. - MuJoCoRunner: applies per-env friction/damping/torque scales. - robot.py: compute_motor_force gains friction/damping scale args. - Configs: DR blocks for mujoco (full) and mjx (delay+noise); clean defaults for mujoco_single/serial; noise/delay anchored to recordings. Reward fixes (rotary_cartpole): - Shift upright reward to [0,1] (was [-1,1]) + alive_bonus, so surviving always beats ending early (kills the "suicide into the limit" policy). - Add balance_bonus * upright * stillness so reward requires upright AND near-zero pendulum velocity (kills the "spin in full loops" policy). Deploy: - eval.py load_policy reconstructs the history/adaptation encoder (auto-detects its dim from the checkpoint) so DR+embedding policies load. Fixes: - MuJoCoRunner._sim_reset referenced self._env (typo) -> self.env, which was breaking every rotary-cartpole reset. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
412 lines
15 KiB
Python
412 lines
15 KiB
Python
r"""Evaluate a trained policy on real hardware (or in simulation).

Loads a checkpoint and runs the policy in a closed loop. For real
hardware the serial runner talks to the ESP32; for sim it uses the
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
in both modes.

Usage (real hardware):

    mjpython scripts/eval.py env=rotary_cartpole runner=serial \
        checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt

Usage (simulation):

    mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
        checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt

If the Hydra config does not declare a ``checkpoint`` key, pass it as
``+checkpoint=...`` (Hydra's syntax for adding a new key).

Controls:

    Space — pause / resume policy (motor stops while paused)
    R     — reset environment
    Esc   — quit
"""
|
|
|
|
import math
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Ensure the project root (the parent of this script's directory) is on
# sys.path so the ``src.`` package imports below resolve when this file is
# run directly as a script. Inserted at index 0 so it takes precedence.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import hydra
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
import structlog
|
|
import torch
|
|
from gymnasium import spaces
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
|
|
|
from src.core.registry import build_env
|
|
from src.models.mlp import SharedMLP
|
|
|
|
logger = structlog.get_logger()


# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists serve as mutable module-level flags: _key_callback
# (invoked by the MuJoCo viewer on key press) writes them, and the eval
# loops poll them each iteration.
_reset_flag = [False]  # set by R; the eval loop clears it after resetting
_paused = [False]  # toggled by Space; the eval loops skip policy steps while set
_quit_flag = [False]  # set by Esc; the eval loops exit once it is set
|
|
|
|
|
|
def _key_callback(keycode: int) -> None:
    """Handle a key press forwarded by the MuJoCo passive viewer.

    Space toggles pause, R requests an environment reset and Esc requests
    quit. Keycodes are GLFW key constants; any other key is ignored.
    """
    glfw_space, glfw_r, glfw_escape = 32, 82, 256
    if keycode == glfw_escape:
        _quit_flag[0] = True
    elif keycode == glfw_r:
        _reset_flag[0] = True
    elif keycode == glfw_space:
        _paused[0] = not _paused[0]
|
|
|
|
|
|
# ── checkpoint loading ───────────────────────────────────────────────
|
|
|
|
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
    """Recover the hidden-layer widths of a SharedMLP from its state dict.

    Walks the ``net.<i>.weight`` keys at even indices (0, 2, 4, ...), since
    the odd indices are the parameter-free ELU activation layers. It stops
    at the first missing index. Each width is the weight's output dimension
    (``shape[0]``).
    """
    widths: list[int] = []
    layer_idx = 0
    while True:
        weight_key = f"net.{layer_idx}.weight"
        if weight_key not in state_dict:
            return tuple(widths)
        widths.append(state_dict[weight_key].shape[0])
        layer_idx += 2
|
|
|
|
|
|
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
|
|
"""Return the history/adaptation encoder output dim, if present.
|
|
|
|
Lets eval reconstruct an embedding policy without knowing the training
|
|
embedding_dim/latent_dim — read it straight from the saved weights.
|
|
"""
|
|
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
|
|
if key in state_dict:
|
|
return state_dict[key].shape[0]
|
|
return None
|
|
|
|
|
|
def load_policy(
    checkpoint_path: str,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device = torch.device("cpu"),
    history_length: int = 0,
    rma_mode: str = "none",
    raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]:
    """Load a trained SharedMLP + observation normalizer from a checkpoint.

    For DR + history-embedding policies (history_length > 0) or RMA deploy
    policies (rma_mode="deploy"), the history/adaptation encoder must be
    reconstructed too — its output dim is read back from the saved weights.

    Args:
        checkpoint_path: Path to a checkpoint holding a ``"policy"`` state
            dict and a ``"state_preprocessor"`` dict with ``running_mean``,
            ``running_variance`` and ``current_count`` entries.
        observation_space: Observation space the policy and normalizer are
            built for.
        action_space: Action space of the policy.
        device: Device the checkpoint is mapped to and the model built on.
        history_length: Forwarded to SharedMLP.
        rma_mode: Forwarded to SharedMLP.
        raw_obs_dim: Forwarded to SharedMLP. Callers in this file pass the
            env's ``observation_space.shape[0]`` (presumably the un-stacked
            per-step observation size — confirm against SharedMLP).

    Returns:
        (model, state_preprocessor) ready for inference, both in eval mode.

    Raises:
        KeyError: if the checkpoint lacks the ``"policy"`` or
            ``"state_preprocessor"`` entries (or the normalizer statistics).
    """
    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    policy_state = ckpt["policy"]
    # Fetch the normalizer stats up front so a checkpoint saved without a
    # preprocessor fails before the model is built.
    scaler_state = ckpt["state_preprocessor"]

    # Infer architecture from saved weights.
    hidden_sizes = _infer_hidden_sizes(policy_state)
    enc_out = _infer_encoder_out_dim(policy_state)

    # Reconstruct model — pass through the encoder config so a DR+embedding
    # checkpoint rebuilds the history encoder with matching dimensions. The
    # fallbacks (32 / 8) apply only when the checkpoint has no encoder.
    model = SharedMLP(
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        hidden_sizes=hidden_sizes,
        history_length=history_length,
        rma_mode=rma_mode,
        raw_obs_dim=raw_obs_dim,
        embedding_dim=enc_out or 32,  # legacy "none" + history
        latent_dim=enc_out or 8,  # RMA deploy adaptation module
    )
    model.load_state_dict(policy_state)
    model.eval()

    # Reconstruct observation normalizer from the saved running statistics.
    state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
    state_preprocessor.running_mean = scaler_state["running_mean"].to(device)
    state_preprocessor.running_variance = scaler_state["running_variance"].to(device)
    state_preprocessor.current_count = scaler_state["current_count"]
    # Freeze the normalizer for eval. eval() is the nn.Module API for
    # clearing the training flag (recursively) rather than setting it by hand.
    state_preprocessor.eval()

    logger.info(
        "checkpoint_loaded",
        path=checkpoint_path,
        hidden_sizes=hidden_sizes,
        encoder_dim=enc_out,
        obs_mean=[round(x, 3) for x in state_preprocessor.running_mean.tolist()],
        obs_std=[round(x, 3) for x in state_preprocessor.running_variance.sqrt().tolist()],
    )

    return model, state_preprocessor
|
|
|
|
|
|
# ── action arrow overlay ─────────────────────────────────────────────
|
|
|
|
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow showing applied torque direction.

    The arrow is placed 2 cm above the body driven by actuator 0. It points
    along that joint's axis (transformed to world frame), flipped by the
    sign of ``action_val``. Its length is ``0.08 * |action_val|``. Green
    means a positive action and red a negative one. The geom is appended to
    ``viewer.user_scn``; the eval loops reset ``user_scn.ngeom`` each frame.

    Does nothing when the action is negligible (``|action_val| < 0.01``),
    when the model has no actuators, or when the user scene has no free
    geom slot left. Without the last check, indexing ``user_scn.geoms`` at
    ``ngeom`` would run past the scene's fixed-size (``maxgeom``) buffer.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    scene = viewer.user_scn
    if scene.ngeom >= scene.maxgeom:
        return

    # Actuator 0's transmission target (assumed to be a joint) and its body.
    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02

    # Joint axis is stored in the body frame; rotate it into world frame.
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Orthonormal frame whose z-axis is the arrow direction. Pick a helper
    # "up" vector that is not (anti-)parallel to z to keep the cross
    # product well-conditioned.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    geom = scene.geoms[scene.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    scene.ngeom += 1
|
|
|
|
|
|
# ── main loops ───────────────────────────────────────────────────────
|
|
|
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: validate the checkpoint, then dispatch the eval.

    The env and runner names come from Hydra's config-group choices,
    defaulting to ``cartpole`` / ``mujoco_single``. The checkpoint path is
    resolved against the directory the script was launched from (Hydra's
    original CWD). The process exits with status 1 if the checkpoint is
    missing or does not exist. The ``serial`` runner goes to hardware eval;
    every other runner goes to simulation eval.
    """
    runtime_choices = HydraConfig.get().runtime.choices
    env_name = runtime_choices.get("env", "cartpole")
    runner_name = runtime_choices.get("runner", "mujoco_single")

    requested = cfg.get("checkpoint", None)
    if requested is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    # Resolve relative paths against original working directory.
    resolved = Path(hydra.utils.get_original_cwd()) / requested
    checkpoint_path = str(resolved)
    if not resolved.exists():
        logger.error("checkpoint_not_found", path=checkpoint_path)
        sys.exit(1)

    evaluate = _eval_serial if runner_name == "serial" else _eval_sim
    evaluate(cfg, env_name, checkpoint_path)
|
|
|
|
|
|
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy in MuJoCo simulation with viewer.

    Builds a single-env MuJoCoRunner from ``cfg.runner`` and loads the
    policy from ``checkpoint_path``. It then runs the policy closed-loop
    while a passive MuJoCo viewer shows the simulated state. Episodes
    auto-reset on termination or truncation. Space/R/Esc are handled
    through the module-level key flags.

    By default the policy's mean action (``outputs["mean_actions"]`` from
    ``model.act``) is applied, so evaluation is deterministic. Pass
    ``+deterministic=false`` to apply the sampled (stochastic) action
    instead. If the model reports no mean action, the sampled one is used.

    The loop is paced to the control period ``dt * substeps``. Only the
    time left after inference, stepping and rendering is slept, so the sim
    runs close to real time instead of lagging behind it. The runner is
    closed even if loading or stepping raises.
    """
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    deterministic = bool(cfg.get("deterministic", True))

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    try:
        device = runner.device
        model, preprocessor = load_policy(
            checkpoint_path, runner.observation_space, runner.action_space, device,
            history_length=runner.config.history_length,
            rma_mode=runner.config.rma_mode,
            raw_obs_dim=runner.env.observation_space.shape[0],
        )

        mj_model = runner._model
        mj_data = runner._data[0]
        dt_ctrl = runner.config.dt * runner.config.substeps

        with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
            obs, _ = runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0

            logger.info(
                "eval_started",
                env=env_name,
                mode="simulation",
                checkpoint=Path(checkpoint_path).name,
                deterministic=deterministic,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                tick_start = time.perf_counter()

                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action, _, outputs = model.act({"states": normalized_obs}, role="policy")
                    if deterministic and isinstance(outputs, dict):
                        # skrl's Gaussian policies report the distribution mean
                        # under "mean_actions"; keep the sample if it is absent.
                        action = outputs.get("mean_actions", action)
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, _ = runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync viewer
                mujoco.mj_forward(mj_model, mj_data)
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 50 == 0:
                    joints = {mj_model.jnt(i).name: round(math.degrees(mj_data.qpos[i]), 1)
                              for i in range(mj_model.njnt)}
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1), **joints,
                    )

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                # Real-time pacing: sleep only for what is left of this
                # control period after inference, physics and rendering.
                remaining = dt_ctrl - (time.perf_counter() - tick_start)
                if remaining > 0:
                    time.sleep(remaining)
    finally:
        runner.close()
|
|
|
|
|
|
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy on real hardware via serial, with digital-twin viewer.

    Builds a SerialRunner from ``cfg.runner`` and loads the policy. It then
    runs the policy closed-loop on the hardware while a passive MuJoCo
    viewer mirrors the measured state. Space pauses, sending "M0" to stop
    the motor while paused. R stops the motor, re-centres, waits for the
    pendulum to settle and resets. Esc quits. A reported reboot or motor
    limit breach stops the motor and ends the session.

    By default the policy's mean action (``outputs["mean_actions"]`` from
    ``model.act``) is applied, so hardware runs are deterministic. Pass
    ``+deterministic=false`` to apply the sampled action instead.

    Safety: however the session ends (Esc, safety stop, an exception, or
    Ctrl+C), "M0" is sent on a best-effort basis before the runner is
    closed. The motor is therefore not left running on the last command.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    deterministic = bool(cfg.get("deterministic", True))

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))

    try:
        device = serial_runner.device
        model, preprocessor = load_policy(
            checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
            history_length=serial_runner.config.history_length,
            rma_mode=serial_runner.config.rma_mode,
            raw_obs_dim=serial_runner.env.observation_space.shape[0],
        )

        # Set up digital-twin MuJoCo model for visualization.
        serial_runner._ensure_viz_model()
        mj_model = serial_runner._viz_model
        mj_data = serial_runner._viz_data

        with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
            obs, _ = serial_runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0

            logger.info(
                "eval_started",
                env=env_name,
                mode="real hardware (serial)",
                port=serial_runner.config.port,
                checkpoint=Path(checkpoint_path).name,
                deterministic=deterministic,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    serial_runner._send("M0")  # safety: stop motor while paused
                    serial_runner._sync_viz()
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action, _, outputs = model.act({"states": normalized_obs}, role="policy")
                    if deterministic and isinstance(outputs, dict):
                        # skrl's Gaussian policies report the distribution mean
                        # under "mean_actions"; keep the sample if it is absent.
                        action = outputs.get("mean_actions", action)
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, info = serial_runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync digital twin with real sensor data.
                serial_runner._sync_viz()
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 25 == 0:
                    state = serial_runner._read_state()
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1),
                        motor_enc=state["encoder_count"],
                        pend_deg=round(state["pendulum_angle"], 1),
                    )

                # Check for safety / disconnection.
                if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                    logger.error(
                        "safety_stop",
                        reboot=info.get("reboot_detected", False),
                        motor_limit=info.get("motor_limit_exceeded", False),
                    )
                    serial_runner._send("M0")
                    break

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    # Auto-reset for next episode.
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                # Real-time pacing is handled by serial_runner.step() (dt sleep).
    finally:
        # Safety: command the motor off on every exit path before releasing
        # the port. Best-effort, because the serial link may already be down
        # (e.g. after a device reboot).
        try:
            serial_runner._send("M0")
        except Exception as exc:
            logger.warning("motor_stop_failed", error=repr(exc))
        serial_runner.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # @hydra.main parses the CLI overrides (env=..., runner=..., checkpoint=...)
    # and calls main() with the composed config.
    main()
|