Files
RL-Sim-Framework/scripts/eval.py
Victor Mylle b37cd26690 feat: sim2real domain randomization + reward fixes for rotary cartpole
Close the sim2real gap for the Furuta pendulum (swings up but can't
balance on hardware). Root causes were (a) no domain randomization, so
the policy overfit one deterministic sim instance, and (b) reward design
flaws that produced degenerate policies.

Domain randomization (runner-level, backend-agnostic):
- BaseRunner: domain_rand config; per-env action-delay buffer (latency),
  Gaussian qpos/qvel sensor noise, per-env dynamics-scale sampling
  (friction/damping/torque), resampled per episode. Sensor noise per step.
- privileged_obs/privileged_dim expose normalized DR factors (mu) for RMA.
- step() now uses clean state for reward/termination, noisy state for the
  observation the policy sees.
- MuJoCoRunner: applies per-env friction/damping/torque scales.
- robot.py: compute_motor_force gains friction/damping scale args.
- Configs: DR blocks for mujoco (full) and mjx (delay+noise); clean
  defaults for mujoco_single/serial; noise/delay anchored to recordings.

Reward fixes (rotary_cartpole):
- Shift upright reward to [0,1] (was [-1,1]) + alive_bonus, so surviving
  always beats ending early (kills the "suicide into the limit" policy).
- Add balance_bonus * upright * stillness so reward requires upright AND
  near-zero pendulum velocity (kills the "spin in full loops" policy).

Deploy:
- eval.py load_policy reconstructs the history/adaptation encoder
  (auto-detects its dim from the checkpoint) so DR+embedding policies load.

Fixes:
- MuJoCoRunner._sim_reset referenced self._env (typo) -> self.env, which
  was breaking every rotary-cartpole reset.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 20:48:25 +02:00

412 lines
15 KiB
Python

"""Evaluate a trained policy on real hardware (or in simulation).
Loads a checkpoint and runs the policy in a closed loop. For real
hardware the serial runner talks to the ESP32; for sim it uses the
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
in both modes.
Usage (real hardware):
mjpython scripts/eval.py env=rotary_cartpole runner=serial \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Usage (simulation):
mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Controls:
Space — pause / resume policy (motor stops while paused)
R — reset environment
Esc — quit
"""
import math
import sys
import time
from pathlib import Path
# Ensure project root is on sys.path
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from gymnasium import spaces
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from skrl.resources.preprocessors.torch import RunningStandardScaler
from src.core.registry import build_env
from src.models.mlp import SharedMLP
logger = structlog.get_logger()
# ── keyboard state ───────────────────────────────────────────────────
_reset_flag = [False]
_paused = [False]
_quit_flag = [False]
def _key_callback(keycode: int) -> None:
"""Called by MuJoCo viewer on key press."""
if keycode == 32: # GLFW_KEY_SPACE
_paused[0] = not _paused[0]
elif keycode == 82: # GLFW_KEY_R
_reset_flag[0] = True
elif keycode == 256: # GLFW_KEY_ESCAPE
_quit_flag[0] = True
# ── checkpoint loading ───────────────────────────────────────────────
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
    """Recover the layer widths of a SharedMLP trunk from its state dict.

    Walks ``net.0.weight``, ``net.2.weight``, ... and records the first
    dimension of each weight, stopping at the first missing key. Odd
    indices are skipped because they hold the activation layers (ELU).
    """
    widths: list[int] = []
    layer_idx = 0
    while True:
        weight_key = f"net.{layer_idx}.weight"
        if weight_key not in state_dict:
            break
        widths.append(state_dict[weight_key].shape[0])
        layer_idx += 2  # step over the activation that follows each linear layer
    return tuple(widths)
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
"""Return the history/adaptation encoder output dim, if present.
Lets eval reconstruct an embedding policy without knowing the training
embedding_dim/latent_dim — read it straight from the saved weights.
"""
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
if key in state_dict:
return state_dict[key].shape[0]
return None
def load_policy(
    checkpoint_path: str,
    observation_space: spaces.Space,
    action_space: spaces.Space,
    device: torch.device = torch.device("cpu"),
    history_length: int = 0,
    rma_mode: str = "none",
    raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]:
    """Rebuild a trained SharedMLP and its observation normalizer from disk.

    The hidden-layer widths and, when present, the history/adaptation
    encoder output width are read back from the saved policy weights, so
    DR + history-embedding policies (history_length > 0) and RMA deploy
    policies (rma_mode="deploy") can be reconstructed without the training
    config.

    Args:
        checkpoint_path: Path to the checkpoint file.
        observation_space: Observation space handed to the model and scaler.
        action_space: Action space handed to the model.
        device: Device the checkpoint tensors are mapped to.
        history_length: Forwarded to SharedMLP.
        rma_mode: Forwarded to SharedMLP.
        raw_obs_dim: Forwarded to SharedMLP.

    Returns:
        (model, state_preprocessor) ready for inference.

    Raises:
        KeyError: If the checkpoint lacks the "policy" or
            "state_preprocessor" entries (or the scaler statistics).
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    policy_state = checkpoint["policy"]

    # Architecture is recovered from the saved tensors themselves.
    hidden_sizes = _infer_hidden_sizes(policy_state)
    encoder_dim = _infer_encoder_out_dim(policy_state)

    # Only one of the two encoder widths is relevant for a given mode; fall
    # back to a fixed default when the checkpoint carries no encoder.
    embedding_dim = encoder_dim if encoder_dim else 32  # legacy "none" + history
    latent_dim = encoder_dim if encoder_dim else 8  # RMA deploy adaptation module

    model = SharedMLP(
        observation_space=observation_space,
        action_space=action_space,
        device=device,
        hidden_sizes=hidden_sizes,
        history_length=history_length,
        rma_mode=rma_mode,
        raw_obs_dim=raw_obs_dim,
        embedding_dim=embedding_dim,
        latent_dim=latent_dim,
    )
    model.load_state_dict(policy_state)
    model.eval()

    # Observation normalizer: copy the saved running statistics over.
    scaler_state = checkpoint["state_preprocessor"]
    state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
    state_preprocessor.running_mean = scaler_state["running_mean"].to(device)
    state_preprocessor.running_variance = scaler_state["running_variance"].to(device)
    state_preprocessor.current_count = scaler_state["current_count"]
    # Mark the normalizer as non-training so eval does not update its stats.
    state_preprocessor.training = False

    def _rounded(tensor: torch.Tensor) -> list:
        """Return the tensor's values as a list rounded to 3 decimals."""
        return [round(value, 3) for value in tensor.tolist()]

    logger.info(
        "checkpoint_loaded",
        path=checkpoint_path,
        hidden_sizes=hidden_sizes,
        obs_mean=_rounded(state_preprocessor.running_mean),
        obs_std=_rounded(state_preprocessor.running_variance.sqrt()),
    )
    return model, state_preprocessor
# ── action arrow overlay ─────────────────────────────────────────────
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Append an arrow geom visualising the applied action to the viewer.

    Nothing is drawn when ``abs(action_val) < 0.01`` or the model has no
    actuators. Otherwise the arrow is anchored 0.02 above the body that
    owns the first actuator's joint, aligned with that joint's axis
    expressed in the world frame (flipped for negative actions), scaled to
    ``0.08 * abs(action_val)`` and coloured green for positive / red for
    negative actions. The geom is written into the next free slot of
    ``viewer.user_scn`` and ``ngeom`` is incremented.
    """
    if abs(action_val) < 0.01:
        return
    if model.nu == 0:
        return

    # Joint driven by the first actuator, and the body that joint belongs to.
    joint_id = model.actuator_trnid[0, 0]
    owner_body = model.jnt_bodyid[joint_id]

    anchor = data.xpos[owner_body].copy()
    anchor[2] += 0.02  # lift the arrow slightly off the body origin

    signed_len = 0.08 * action_val
    world_axis = data.xmat[owner_body].reshape(3, 3) @ model.jnt_axis[joint_id]
    pointing = world_axis * np.sign(signed_len)

    # Build an orthonormal frame whose z-axis is the arrow direction.
    z_axis = pointing / (np.linalg.norm(pointing) + 1e-8)
    if abs(z_axis[2]) < 0.99:
        reference = np.array([0, 0, 1])
    else:
        # Direction is nearly vertical: pick another reference so the cross
        # product does not degenerate.
        reference = np.array([0, 1, 0])
    x_axis = np.cross(reference, z_axis)
    x_axis /= np.linalg.norm(x_axis) + 1e-8
    y_axis = np.cross(z_axis, x_axis)
    frame = np.column_stack([x_axis, y_axis, z_axis]).flatten()

    if action_val > 0:
        colour = [0.2, 0.8, 0.2, 0.8]
    else:
        colour = [0.8, 0.2, 0.2, 0.8]
    rgba = np.array(colour, dtype=np.float32)

    scene = viewer.user_scn
    mujoco.mjv_initGeom(
        scene.geoms[scene.ngeom],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(signed_len)]),
        pos=anchor,
        mat=frame,
        rgba=rgba,
    )
    scene.ngeom += 1
# ── main loops ───────────────────────────────────────────────────────
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: validate the checkpoint and start the eval loop.

    The env/runner names come from Hydra's runtime choices. The checkpoint
    path is resolved against the directory the script was launched from;
    the process exits with status 1 when it is missing or unspecified.
    The "serial" runner selects the hardware loop, anything else the
    simulation loop.
    """
    selected = HydraConfig.get().runtime.choices
    env_name = selected.get("env", "cartpole")
    runner_name = selected.get("runner", "mujoco_single")

    checkpoint_arg = cfg.get("checkpoint", None)
    if checkpoint_arg is None:
        logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
        sys.exit(1)

    # Resolve relative paths against original working directory.
    resolved = Path(hydra.utils.get_original_cwd()) / checkpoint_arg
    checkpoint_path = str(resolved)
    if not Path(checkpoint_path).exists():
        logger.error("checkpoint_not_found", path=checkpoint_path)
        sys.exit(1)

    evaluate = _eval_serial if runner_name == "serial" else _eval_sim
    evaluate(cfg, env_name, checkpoint_path)
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy in MuJoCo simulation with viewer.

    Builds a single-environment MuJoCoRunner, loads the policy from
    *checkpoint_path* and runs it in a closed loop inside a passive MuJoCo
    viewer until the viewer is closed or Esc is pressed. Space pauses the
    loop and R resets the environment (see ``_key_callback``). Episodes
    auto-reset on termination/truncation.

    The runner is closed even if policy loading or the loop raises.

    Args:
        cfg: Hydra config; ``cfg.runner`` supplies the runner settings.
        env_name: Name passed to ``build_env``.
        checkpoint_path: Path of the checkpoint to load.
    """
    from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1  # eval always drives exactly one environment
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    # Everything after construction is guarded so the runner is always
    # released, including when load_policy or the viewer loop raises.
    try:
        device = runner.device
        model, preprocessor = load_policy(
            checkpoint_path, runner.observation_space, runner.action_space, device,
            history_length=runner.config.history_length,
            rma_mode=runner.config.rma_mode,
            raw_obs_dim=runner.env.observation_space.shape[0],
        )
        mj_model = runner._model
        mj_data = runner._data[0]
        dt_ctrl = runner.config.dt * runner.config.substeps

        with mujoco.viewer.launch_passive(
            mj_model, mj_data, key_callback=_key_callback
        ) as viewer:
            obs, _ = runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0
            logger.info(
                "eval_started",
                env=env_name,
                mode="simulation",
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    # Keep the viewer responsive without stepping the sim.
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, info = runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync viewer
                mujoco.mj_forward(mj_model, mj_data)
                viewer.user_scn.ngeom = 0  # drop the previous frame's overlay
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 50 == 0:
                    # Look each joint up through jnt_qposadr: qpos is indexed
                    # by position address, which differs from the joint index
                    # as soon as the model contains a multi-DoF joint.
                    # NOTE(review): values are reported in degrees, which
                    # assumes rotational (radian) joints — confirm for slides.
                    joints = {
                        mj_model.jnt(i).name: round(
                            math.degrees(mj_data.qpos[mj_model.jnt_qposadr[i]]), 1
                        )
                        for i in range(mj_model.njnt)
                    }
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1), **joints,
                    )

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    obs, _ = runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                time.sleep(dt_ctrl)
    finally:
        runner.close()
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
    """Evaluate policy on real hardware via serial, with digital-twin viewer.

    Builds a SerialRunner, loads the policy from *checkpoint_path* and runs
    it in a closed loop while a passive MuJoCo viewer mirrors the robot
    state. Space pauses (motor stopped), R re-centres and resets, Esc
    quits (see ``_key_callback``). The loop aborts when the runner reports
    ``reboot_detected`` or ``motor_limit_exceeded`` in ``info``.

    However the function exits — normal quit, safety stop, or an exception
    during policy loading or the loop — the motor stop command ("M0") is
    sent and the runner is closed, so the hardware is never left driving
    with the last commanded action.

    Args:
        cfg: Hydra config; ``cfg.runner`` supplies the runner settings.
        env_name: Name passed to ``build_env``.
        checkpoint_path: Path of the checkpoint to load.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))

    try:
        device = serial_runner.device
        model, preprocessor = load_policy(
            checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
            history_length=serial_runner.config.history_length,
            rma_mode=serial_runner.config.rma_mode,
            raw_obs_dim=serial_runner.env.observation_space.shape[0],
        )

        # Set up digital-twin MuJoCo model for visualization.
        serial_runner._ensure_viz_model()
        mj_model = serial_runner._viz_model
        mj_data = serial_runner._viz_data

        with mujoco.viewer.launch_passive(
            mj_model, mj_data, key_callback=_key_callback
        ) as viewer:
            obs, _ = serial_runner.reset()
            step = 0
            episode = 0
            episode_reward = 0.0
            logger.info(
                "eval_started",
                env=env_name,
                mode="real hardware (serial)",
                port=serial_runner.config.port,
                checkpoint=Path(checkpoint_path).name,
                controls="Space=pause, R=reset, Esc=quit",
            )

            while viewer.is_running() and not _quit_flag[0]:
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    # Stop, re-centre and let the pendulum settle before
                    # starting the next episode.
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0
                    logger.info("reset", episode=episode)

                if _paused[0]:
                    serial_runner._send("M0")  # safety: stop motor while paused
                    serial_runner._sync_viz()
                    viewer.sync()
                    time.sleep(0.05)
                    continue

                # Policy inference
                with torch.no_grad():
                    normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
                    action = model.act({"states": normalized_obs}, role="policy")[0]
                    action = action.clamp(-1.0, 1.0)

                obs, reward, terminated, truncated, info = serial_runner.step(action)
                episode_reward += reward.item()
                step += 1

                # Sync digital twin with real sensor data.
                serial_runner._sync_viz()
                viewer.user_scn.ngeom = 0  # drop the previous frame's overlay
                _add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
                viewer.sync()

                if step % 25 == 0:
                    state = serial_runner._read_state()
                    logger.debug(
                        "step", n=step, reward=round(reward.item(), 3),
                        action=round(action[0, 0].item(), 2),
                        ep_reward=round(episode_reward, 1),
                        motor_enc=state["encoder_count"],
                        pend_deg=round(state["pendulum_angle"], 1),
                    )

                # Check for safety / disconnection.
                if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
                    logger.error(
                        "safety_stop",
                        reboot=info.get("reboot_detected", False),
                        motor_limit=info.get("motor_limit_exceeded", False),
                    )
                    serial_runner._send("M0")
                    break

                if terminated.any() or truncated.any():
                    logger.info(
                        "episode_done", episode=episode, steps=step,
                        total_reward=round(episode_reward, 2),
                        reason="terminated" if terminated.any() else "truncated",
                    )
                    # Auto-reset for next episode.
                    obs, _ = serial_runner.reset()
                    step = 0
                    episode += 1
                    episode_reward = 0.0

                # Real-time pacing is handled by serial_runner.step() (dt sleep).
    finally:
        # Always command the motor to stop before releasing the runner; the
        # nested finally guarantees close() runs even if the stop command
        # itself fails (e.g. the link already dropped).
        try:
            serial_runner._send("M0")
        finally:
            serial_runner.close()
# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()