diff --git a/assets/rotary_cartpole/hardware.yaml b/assets/rotary_cartpole/hardware.yaml
new file mode 100644
index 0000000..a968c91
--- /dev/null
+++ b/assets/rotary_cartpole/hardware.yaml
@@ -0,0 +1,23 @@
+# Rotary cartpole (Furuta pendulum) — real hardware config.
+# Describes the physical device for the SerialRunner.
+# Robot-specific constants that don't belong in the runner config
+# (which is machine-specific: port, baud) or the env config
+# (which is task-specific: rewards, max_steps).
+
+encoder:
+ ppr: 11 # pulses per revolution (before quadrature)
+ gear_ratio: 30.0 # gearbox ratio
+ # counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
+
+safety:
+ max_motor_angle_deg: 90.0 # hard termination limit (0 = disabled)
+ soft_limit_deg: 40.0 # progressive penalty ramp starts here
+
+reset:
+ drive_speed: 80 # PWM magnitude for bang-bang drive-to-center
+ deadband: 15 # encoder count threshold to consider "centered"
+ drive_timeout: 3.0 # seconds before giving up on drive-to-center
+ settle_angle_deg: 2.0 # pendulum angle threshold for "still" (degrees)
+ settle_vel_dps: 5.0 # pendulum velocity threshold (deg/s)
+ settle_duration: 0.5 # how long pendulum must stay still (seconds)
+ settle_timeout: 30.0 # give up waiting after this (seconds)
diff --git a/assets/rotary_cartpole/recordings/capture_20260311_215608.npz b/assets/rotary_cartpole/recordings/capture_20260311_215608.npz
new file mode 100644
index 0000000..850edae
Binary files /dev/null and b/assets/rotary_cartpole/recordings/capture_20260311_215608.npz differ
diff --git a/assets/rotary_cartpole/robot.yaml b/assets/rotary_cartpole/robot.yaml
index d18686a..3323c23 100644
--- a/assets/rotary_cartpole/robot.yaml
+++ b/assets/rotary_cartpole/robot.yaml
@@ -1,19 +1,20 @@
-# Rotary cartpole (Furuta pendulum) — robot hardware config.
-# Lives next to the URDF so all robot-specific settings are in one place.
+# Tuned robot config — generated by src.sysid.optimize
+# Original: robot.yaml
+# Run `python -m src.sysid.visualize` to compare real vs sim.
urdf: rotary_cartpole.urdf
-
actuators:
- - joint: motor_joint
- type: motor # direct torque control
- gear: 0.064 # stall torque @ 58.8% PWM: 0.108 × 150/255 = 0.064 N·m
- ctrl_range: [-1.0, 1.0]
- damping: 0.003 # viscous back-EMF only (small)
- filter_tau: 0.03 # mechanical time constant ~30ms (37mm gearmotor)
-
+- joint: motor_joint
+ type: motor
+ gear: 0.176692
+ ctrl_range:
+ - -1.0
+ - 1.0
+ damping: 0.009505
+ filter_tau: 0.040906
joints:
motor_joint:
- armature: 0.0001 # reflected rotor inertia: ~1e-7 × 30² = 9e-5 kg·m²
- frictionloss: 0.03 # disabled — may slow MJX (constraint-based solver)
+ armature: 0.001389
+ frictionloss: 0.002179
pendulum_joint:
- damping: 0.0001 # bearing friction
+ damping: 6.1e-05
diff --git a/assets/rotary_cartpole/rotary_cartpole.urdf b/assets/rotary_cartpole/rotary_cartpole.urdf
index 327c023..cb7f392 100644
--- a/assets/rotary_cartpole/rotary_cartpole.urdf
+++ b/assets/rotary_cartpole/rotary_cartpole.urdf
@@ -1,106 +1,80 @@
-
+
-
-
-
-
-
+
-
-
-
+
+
+
-
+
-
+
-
+
-
+
-
-
-
+
+
-
-
-
-
-
+
+
+
-
+
-
+
-
+
-
+
-
-
-
-
-
-
-
-
+
+
+
+
+
+
-
-
-
-
-
-
+
+
+
-
+
-
+
-
+
-
+
-
-
-
-
-
-
-
+
+
+
+
+
-
-
+
\ No newline at end of file
diff --git a/assets/rotary_cartpole/sysid_result.json b/assets/rotary_cartpole/sysid_result.json
new file mode 100644
index 0000000..7673ddc
--- /dev/null
+++ b/assets/rotary_cartpole/sysid_result.json
@@ -0,0 +1,70 @@
+{
+ "best_params": {
+ "arm_mass": 0.012391951282440451,
+ "arm_com_x": 0.014950488360794875,
+ "arm_com_y": 0.006089886527968399,
+ "arm_com_z": 0.004470745447817278,
+ "pendulum_mass": 0.035508993892747365,
+ "pendulum_com_x": 0.06432778588634695,
+ "pendulum_com_y": -0.05999895841669392,
+ "pendulum_com_z": 0.0008769789937631209,
+ "pendulum_ixx": 3.139576982078822e-05,
+ "pendulum_iyy": 9.431951659638859e-06,
+ "pendulum_izz": 4.07315891863556e-05,
+ "pendulum_ixy": -1.8892943833253423e-06,
+ "actuator_gear": 0.17669161390939517,
+ "actuator_filter_tau": 0.040905643692382504,
+ "motor_damping": 0.009504542103348917,
+ "pendulum_damping": 6.128535042404019e-05,
+ "motor_armature": 0.0013894759540138252,
+ "motor_frictionloss": 0.002179448047511452
+ },
+ "best_cost": 0.7471380533090072,
+ "recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260311_215608.npz",
+ "param_names": [
+ "arm_mass",
+ "arm_com_x",
+ "arm_com_y",
+ "arm_com_z",
+ "pendulum_mass",
+ "pendulum_com_x",
+ "pendulum_com_y",
+ "pendulum_com_z",
+ "pendulum_ixx",
+ "pendulum_iyy",
+ "pendulum_izz",
+ "pendulum_ixy",
+ "actuator_gear",
+ "actuator_filter_tau",
+ "motor_damping",
+ "pendulum_damping",
+ "motor_armature",
+ "motor_frictionloss"
+ ],
+ "defaults": {
+ "arm_mass": 0.01,
+ "arm_com_x": 5e-05,
+ "arm_com_y": 0.0065,
+ "arm_com_z": 0.00563,
+ "pendulum_mass": 0.015,
+ "pendulum_com_x": 0.1583,
+ "pendulum_com_y": -0.0983,
+ "pendulum_com_z": 0.0,
+ "pendulum_ixx": 6.16e-06,
+ "pendulum_iyy": 6.16e-06,
+ "pendulum_izz": 1.23e-05,
+ "pendulum_ixy": 6.1e-06,
+ "actuator_gear": 0.064,
+ "actuator_filter_tau": 0.03,
+ "motor_damping": 0.003,
+ "pendulum_damping": 0.0001,
+ "motor_armature": 0.0001,
+ "motor_frictionloss": 0.03
+ },
+ "timestamp": "2026-03-11T22:08:04.782736",
+ "history_summary": {
+ "first_cost": 3.909456214944022,
+ "final_cost": 0.7471380533090072,
+ "generations": 200
+ }
+}
\ No newline at end of file
diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml
index 19a666b..65049af 100644
--- a/configs/env/rotary_cartpole.yaml
+++ b/configs/env/rotary_cartpole.yaml
@@ -1,3 +1,10 @@
max_steps: 1000
robot_path: assets/rotary_cartpole
-reward_upright_scale: 1.0
\ No newline at end of file
+reward_upright_scale: 1.0
+speed_penalty_scale: 0.1
+
+# ── HPO search ranges ────────────────────────────────────────────────
+hpo:
+ reward_upright_scale: {min: 0.5, max: 5.0}
+ speed_penalty_scale: {min: 0.01, max: 1.0}
+ max_steps: {values: [500, 1000, 2000]}
\ No newline at end of file
diff --git a/configs/runner/mujoco_single.yaml b/configs/runner/mujoco_single.yaml
new file mode 100644
index 0000000..3b61919
--- /dev/null
+++ b/configs/runner/mujoco_single.yaml
@@ -0,0 +1,7 @@
+# Single-env MuJoCo runner — mimics real hardware timing.
+# dt × substeps = 0.002 × 10 = 0.02 s → 50 Hz control, same as serial runner.
+
+num_envs: 1
+device: cpu
+dt: 0.002
+substeps: 10
diff --git a/configs/runner/serial.yaml b/configs/runner/serial.yaml
new file mode 100644
index 0000000..9b11c72
--- /dev/null
+++ b/configs/runner/serial.yaml
@@ -0,0 +1,11 @@
+# Serial runner — communicates with real hardware over USB/serial.
+# Always single-env, CPU-only. Override port on CLI:
+# python train.py runner=serial runner.port=/dev/ttyUSB0
+
+num_envs: 1
+device: cpu
+port: /dev/cu.usbserial-0001
+baud: 115200
+dt: 0.02 # control loop period (50 Hz)
+no_data_timeout: 2.0 # seconds of silence before declaring disconnect
+encoder_jump_threshold: 200 # encoder tick jump → reboot detection
diff --git a/configs/sysid.yaml b/configs/sysid.yaml
new file mode 100644
index 0000000..492c3c0
--- /dev/null
+++ b/configs/sysid.yaml
@@ -0,0 +1,25 @@
+# System identification defaults.
+# Override via CLI: python -m src.sysid.optimize sysid.max_generations=50
+#
+# These are NOT Hydra config groups — the sysid scripts use argparse.
+# This file serves as documentation and can be loaded by custom wrappers.
+
+capture:
+ port: /dev/cu.usbserial-0001
+ baud: 115200
+ duration: 20.0 # seconds
+ amplitude: 180 # max PWM magnitude (0–255)
+ hold_min_ms: 50 # PRBS min hold time
+ hold_max_ms: 300 # PRBS max hold time
+ dt: 0.02 # sample period (50 Hz)
+
+optimize:
+ sigma0: 0.3 # CMA-ES initial step size (in [0,1] normalised space)
+ population_size: 20 # candidates per generation
+ max_generations: 200 # total generations (~4000 evaluations)
+ sim_dt: 0.002 # MuJoCo physics timestep
+ substeps: 10 # physics substeps per control step (ctrl_dt = 0.02s)
+ pos_weight: 1.0 # MSE weight for angle errors
+ vel_weight: 0.1 # MSE weight for velocity errors
+ window_duration: 0.5 # multiple-shooting window length (s); 0 = open-loop
+ seed: 42
diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml
index d3d1786..5167942 100644
--- a/configs/training/ppo.yaml
+++ b/configs/training/ppo.yaml
@@ -12,5 +12,23 @@ entropy_loss_scale: 0.05
log_interval: 1000
checkpoint_interval: 50000
+initial_log_std: 0.5
+min_log_std: -2.0
+max_log_std: 2.0
+
+record_video_every: 10000
+
# ClearML remote execution (GPU worker)
remote: false
+
+# ── HPO search ranges ────────────────────────────────────────────────
+# Read by scripts/hpo.py — ignored by TrainerConfig during training.
+hpo:
+ learning_rate: {min: 0.00005, max: 0.001}
+ clip_ratio: {min: 0.1, max: 0.3}
+ discount_factor: {min: 0.98, max: 0.999}
+ gae_lambda: {min: 0.9, max: 0.99}
+ entropy_loss_scale: {min: 0.0001, max: 0.1}
+ value_loss_scale: {min: 0.1, max: 1.0}
+ learning_epochs: {min: 2, max: 8, type: int}
+ mini_batches: {values: [2, 4, 8, 16]}
diff --git a/configs/training/ppo_mjx.yaml b/configs/training/ppo_mjx.yaml
index 6f098ab..cbf309a 100644
--- a/configs/training/ppo_mjx.yaml
+++ b/configs/training/ppo_mjx.yaml
@@ -1,22 +1,18 @@
# PPO tuned for MJX (1024+ parallel envs on GPU).
+# Inherits defaults + HPO ranges from ppo.yaml.
# With 1024 envs, each timestep collects 1024 samples, so total_timesteps
# can be much lower than the CPU config.
-hidden_sizes: [128, 128]
-total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps
-rollout_steps: 1024 # PPO batch = 1024 envs × 1024 steps = 1M samples
-learning_epochs: 4
-mini_batches: 32 # keep mini-batch size similar to CPU config (~32K)
-discount_factor: 0.99
-gae_lambda: 0.95
-learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling)
-clip_ratio: 0.2
-value_loss_scale: 0.5
-entropy_loss_scale: 0.05
-log_interval: 100 # log more often (shorter run)
+defaults:
+ - ppo
+ - _self_
+
+total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps
+mini_batches: 32 # keep mini-batch size similar (~32K)
+learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling)
+log_interval: 100
checkpoint_interval: 10000
record_video_every: 10000
-# ClearML remote execution (GPU worker)
remote: false
diff --git a/configs/training/ppo_real.yaml b/configs/training/ppo_real.yaml
new file mode 100644
index 0000000..b165e56
--- /dev/null
+++ b/configs/training/ppo_real.yaml
@@ -0,0 +1,27 @@
+# PPO tuned for single-env real-time training on real hardware.
+# Inherits defaults + HPO ranges from ppo.yaml.
+# ~50 Hz control × 1 env = ~50 timesteps/s.
+# 100k timesteps ≈ 33 minutes of wall-clock training.
+
+defaults:
+ - ppo
+ - _self_
+
+hidden_sizes: [256, 256]
+total_timesteps: 100000
+learning_epochs: 5
+learning_rate: 0.001 # conservative — can't undo real-world damage
+entropy_loss_scale: 0.0001
+log_interval: 1024
+checkpoint_interval: 5000 # frequent saves — can't rewind real hardware
+initial_log_std: -0.5 # moderate initial exploration
+min_log_std: -4.0
+max_log_std: 0.0 # cap σ at 1.0
+
+# Never run real-hardware training remotely
+remote: false
+
+# Tighter HPO ranges for real hardware (override base ppo.yaml ranges)
+hpo:
+ entropy_loss_scale: {min: 0.00005, max: 0.001}
+ learning_rate: {min: 0.0003, max: 0.003}
diff --git a/configs/training/ppo_single.yaml b/configs/training/ppo_single.yaml
new file mode 100644
index 0000000..914a1f9
--- /dev/null
+++ b/configs/training/ppo_single.yaml
@@ -0,0 +1,23 @@
+# PPO tuned for single-env simulation — mimics real hardware training.
+# Inherits defaults + HPO ranges from ppo.yaml.
+# Same 50 Hz control (runner=mujoco_single), 1 env, conservative hypers.
+# Sim runs ~100× faster than real time, so we can afford more timesteps.
+
+defaults:
+ - ppo
+ - _self_
+
+hidden_sizes: [256, 256]
+total_timesteps: 500000
+learning_epochs: 5
+learning_rate: 0.001
+entropy_loss_scale: 0.0001
+log_interval: 1024
+checkpoint_interval: 10000
+initial_log_std: -0.5
+min_log_std: -4.0
+max_log_std: 0.0
+
+record_video_every: 50000
+
+remote: false
diff --git a/requirements.txt b/requirements.txt
index 693a91e..21ad8f7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,10 @@ imageio
imageio-ffmpeg
structlog
pyyaml
+pyserial
+cmaes
+matplotlib
+smac>=2.0.0
+ConfigSpace
+hpbandster
pytest
\ No newline at end of file
diff --git a/scripts/hpo.py b/scripts/hpo.py
new file mode 100644
index 0000000..270ace2
--- /dev/null
+++ b/scripts/hpo.py
@@ -0,0 +1,340 @@
+"""Hyperparameter optimization for RL-Framework using ClearML + SMAC3.
+
+Automatically creates a base training task (via Task.create), reads HPO
+search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
+and launches SMAC3 Successive Halving optimization.
+
+Usage:
+ python scripts/hpo.py \
+ --env rotary_cartpole \
+ --runner mujoco_single \
+ --training ppo_single \
+ --queue gpu-queue
+
+ # Or use an existing base task:
+ python scripts/hpo.py --base-task-id <TASK_ID>
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+import time
+from pathlib import Path
+
+# Ensure project root is on sys.path
+_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
+if _PROJECT_ROOT not in sys.path:
+ sys.path.insert(0, _PROJECT_ROOT)
+
+import structlog
+from clearml import Task
+from clearml.automation import (
+ DiscreteParameterRange,
+ HyperParameterOptimizer,
+ UniformIntegerParameterRange,
+ UniformParameterRange,
+)
+from omegaconf import OmegaConf
+
+logger = structlog.get_logger()
+
+
+def _load_hydra_config(
+ env: str, runner: str, training: str
+) -> dict:
+ """Load and merge Hydra configs to extract HPO ranges.
+
+ We read the YAML files directly (without running Hydra) so this script
+ doesn't need @hydra.main — it's a ClearML optimizer, not a training job.
+ """
+ configs_dir = Path(__file__).resolve().parent.parent / "configs"
+
+ # Load training config (handles defaults: [ppo] inheritance)
+ training_path = configs_dir / "training" / f"{training}.yaml"
+ training_cfg = OmegaConf.load(training_path)
+
+ # If the training config has defaults pointing to a base, load + merge
+ if "defaults" in training_cfg:
+ defaults = OmegaConf.to_container(training_cfg.defaults)
+ base_cfg = OmegaConf.create({})
+ for d in defaults:
+ if isinstance(d, str):
+ base_path = configs_dir / "training" / f"{d}.yaml"
+ if base_path.exists():
+ loaded = OmegaConf.load(base_path)
+ base_cfg = OmegaConf.merge(base_cfg, loaded)
+ # Remove defaults key and merge
+ training_no_defaults = {
+ k: v for k, v in OmegaConf.to_container(training_cfg).items()
+ if k != "defaults"
+ }
+ training_cfg = OmegaConf.merge(base_cfg, OmegaConf.create(training_no_defaults))
+
+ # Load env config
+ env_path = configs_dir / "env" / f"{env}.yaml"
+ env_cfg = OmegaConf.load(env_path) if env_path.exists() else OmegaConf.create({})
+
+ return {
+ "training": OmegaConf.to_container(training_cfg, resolve=True),
+ "env": OmegaConf.to_container(env_cfg, resolve=True),
+ }
+
+
+def _build_hyper_parameters(config: dict) -> list:
+ """Build ClearML parameter ranges from hpo: blocks in config.
+
+ Reads training.hpo and env.hpo dicts and creates appropriate
+ ClearML parameter range objects.
+
+ Each hpo entry can have:
+ {min, max} → UniformParameterRange (float)
+ {min, max, type: int} → UniformIntegerParameterRange
+ {min, max, log: true} → UniformParameterRange with log scale
+ {values: [...]} → DiscreteParameterRange
+ """
+ params = []
+
+ for section in ("training", "env"):
+ hpo_ranges = config.get(section, {}).get("hpo", {})
+ if not hpo_ranges:
+ continue
+
+ for param_name, spec in hpo_ranges.items():
+ hydra_key = f"Hydra/{section}.{param_name}"
+
+ if "values" in spec:
+ params.append(
+ DiscreteParameterRange(hydra_key, values=spec["values"])
+ )
+ elif "min" in spec and "max" in spec:
+ if spec.get("type") == "int":
+ params.append(
+ UniformIntegerParameterRange(
+ hydra_key,
+ min_value=int(spec["min"]),
+ max_value=int(spec["max"]),
+ )
+ )
+ else:
+ step = spec.get("step", None)
+ params.append(
+ UniformParameterRange(
+ hydra_key,
+ min_value=float(spec["min"]),
+ max_value=float(spec["max"]),
+ step_size=step,
+ )
+ )
+ else:
+ logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)
+
+ return params
+
+
+def _create_base_task(
+ env: str, runner: str, training: str, queue: str
+) -> str:
+ """Create a base ClearML task without executing it.
+
+ Uses Task.create() to register a task pointing at scripts/train.py
+ with the correct Hydra overrides. The HPO optimizer will clone this.
+ """
+ script_path = str(Path(__file__).resolve().parent / "train.py")
+ project_root = str(Path(__file__).resolve().parent.parent)
+
+ base_task = Task.create(
+ project_name="RL-Framework",
+ task_name=f"{env}-{runner}-{training} (HPO base)",
+ task_type=Task.TaskTypes.training,
+ script=script_path,
+ working_directory=project_root,
+ argparse_args=[
+ f"env={env}",
+ f"runner={runner}",
+ f"training={training}",
+ ],
+ add_task_init_call=False,
+ )
+
+ # Set docker config
+ base_task.set_base_docker(
+ "registry.kube.optimize/worker-image:latest",
+ docker_setup_bash_script=(
+ "apt-get update && apt-get install -y --no-install-recommends "
+ "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
+ "&& pip install 'jax[cuda12]' mujoco-mjx"
+ ),
+ )
+
+ req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
+ base_task.set_packages(str(req_file))
+
+ task_id = base_task.id
+ logger.info("base_task_created", task_id=task_id, task_name=base_task.name)
+ return task_id
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Hyperparameter optimization for RL-Framework"
+ )
+ parser.add_argument(
+ "--base-task-id",
+ type=str,
+ default=None,
+ help="Existing ClearML task ID to use as base (skip auto-creation)",
+ )
+ parser.add_argument("--env", type=str, default="rotary_cartpole")
+ parser.add_argument("--runner", type=str, default="mujoco_single")
+ parser.add_argument("--training", type=str, default="ppo_single")
+ parser.add_argument("--queue", type=str, default="gpu-queue")
+ parser.add_argument(
+ "--max-concurrent", type=int, default=2,
+ help="Maximum concurrent trial tasks",
+ )
+ parser.add_argument(
+ "--total-trials", type=int, default=200,
+ help="Total HPO trial budget",
+ )
+ parser.add_argument(
+ "--min-budget", type=int, default=3,
+ help="Minimum budget (epochs) per trial",
+ )
+ parser.add_argument(
+ "--max-budget", type=int, default=81,
+ help="Maximum budget (epochs) for promoted trials",
+ )
+ parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
+ parser.add_argument(
+ "--time-limit-hours", type=float, default=72,
+ help="Total wall-clock time limit in hours",
+ )
+ parser.add_argument(
+ "--objective-metric", type=str, default="Reward / Total reward (mean)",
+ help="ClearML scalar metric title to optimize",
+ )
+ parser.add_argument(
+ "--objective-series", type=str, default=None,
+ help="ClearML scalar metric series (default: same as title)",
+ )
+ parser.add_argument(
+ "--maximize", action="store_true", default=True,
+ help="Maximize the objective (default)",
+ )
+ parser.add_argument(
+ "--minimize", action="store_true", default=False,
+ help="Minimize the objective",
+ )
+ parser.add_argument(
+ "--dry-run", action="store_true",
+ help="Print search space and exit without running",
+ )
+ args = parser.parse_args()
+
+ objective_sign = "min" if args.minimize else "max"
+
+ # ── Load config and build search space ────────────────────────
+ config = _load_hydra_config(args.env, args.runner, args.training)
+ hyper_parameters = _build_hyper_parameters(config)
+
+ if not hyper_parameters:
+ logger.error(
+ "no_hpo_ranges_found",
+ hint="Add 'hpo:' blocks to your training and/or env YAML configs",
+ )
+ return
+
+ if args.dry_run:
+ print(f"\nSearch space ({len(hyper_parameters)} parameters):")
+ for p in hyper_parameters:
+ print(f" {p.name}: {p}")
+ print(f"\nObjective: {args.objective_metric} ({objective_sign})")
+ return
+
+ # ── Create or reuse base task ─────────────────────────────────
+ if args.base_task_id:
+ base_task_id = args.base_task_id
+ logger.info("using_existing_base_task", task_id=base_task_id)
+ else:
+ base_task_id = _create_base_task(
+ args.env, args.runner, args.training, args.queue
+ )
+
+ # ── Initialize ClearML HPO task ───────────────────────────────
+ Task.ignore_requirements("torch")
+ task = Task.init(
+ project_name="RL-Framework",
+ task_name=f"HPO {args.env}-{args.runner}-{args.training}",
+ task_type=Task.TaskTypes.optimizer,
+ reuse_last_task_id=False,
+ )
+ task.set_base_docker(
+ docker_image="registry.kube.optimize/worker-image:latest",
+ docker_arguments=[
+ "-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
+ "-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
+ "-e", "CLEARML_AGENT_FORCE_SYSTEM_SITE_PACKAGES=1",
+ ],
+ )
+ req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
+ task.set_packages(str(req_file))
+
+ # ── Build objective metric ────────────────────────────────────
+ # skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
+ objective_title = args.objective_metric
+ objective_series = args.objective_series or objective_title
+
+ # ── Launch optimizer ──────────────────────────────────────────
+ from src.hpo.smac3 import OptimizerSMAC
+
+ optimizer = HyperParameterOptimizer(
+ base_task_id=base_task_id,
+ hyper_parameters=hyper_parameters,
+ objective_metric_title=objective_title,
+ objective_metric_series=objective_series,
+ objective_metric_sign=objective_sign,
+ optimizer_class=OptimizerSMAC,
+ execution_queue=args.queue,
+ max_number_of_concurrent_tasks=args.max_concurrent,
+ total_max_jobs=args.total_trials,
+ min_iteration_per_job=args.min_budget,
+ max_iteration_per_job=args.max_budget,
+ pool_period_min=1,
+ time_limit_per_job=240, # 4 hours per trial max
+ eta=args.eta,
+ )
+
+ # Send this HPO controller to a remote services worker
+ task.execute_remotely(queue_name="services", exit_process=True)
+
+ # Reporting and time limits
+ optimizer.set_report_period(1)
+ optimizer.set_time_limit(in_minutes=int(args.time_limit_hours * 60))
+
+ # Start and wait
+ optimizer.start()
+ optimizer.wait()
+
+ # Get top experiments
+ max_retries = 5
+ for attempt in range(max_retries):
+ try:
+ top_exp = optimizer.get_top_experiments(top_k=10)
+ logger.info("top_experiments_retrieved", count=len(top_exp))
+ for i, t in enumerate(top_exp):
+ logger.info("top_experiment", rank=i + 1, task_id=t.id, name=t.name)
+ break
+ except Exception as e:
+ logger.warning("retry_get_top_experiments", attempt=attempt + 1, error=str(e))
+ if attempt < max_retries - 1:
+ time.sleep(5.0 * (2 ** attempt))
+ else:
+ logger.error("could_not_retrieve_top_experiments")
+
+ optimizer.stop()
+ logger.info("hpo_complete")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/sysid.py b/scripts/sysid.py
new file mode 100644
index 0000000..7727bfb
--- /dev/null
+++ b/scripts/sysid.py
@@ -0,0 +1,57 @@
+"""Unified CLI for system identification tools.
+
+Usage:
+ python scripts/sysid.py capture --robot-path assets/rotary_cartpole --duration 20
+ python scripts/sysid.py optimize --robot-path assets/rotary_cartpole --recording <capture>.npz
+ python scripts/sysid.py visualize --recording <capture>.npz
+ python scripts/sysid.py export --robot-path assets/rotary_cartpole --result <result>.json
+"""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+# Ensure project root is on sys.path
+_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
+if _PROJECT_ROOT not in sys.path:
+ sys.path.insert(0, _PROJECT_ROOT)
+
+
+def main() -> None:
+ if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
+ print(
 "Usage: python scripts/sysid.py <command> [options]\n"
+ "\n"
+ "Commands:\n"
+ " capture Record real robot trajectory under PRBS excitation\n"
+ " optimize Run CMA-ES parameter optimization\n"
+ " visualize Plot real vs simulated trajectories\n"
+ " export Write tuned URDF + robot.yaml files\n"
+ "\n"
 "Run 'python scripts/sysid.py <command> --help' for command-specific options."
+ )
+ sys.exit(0)
+
+ command = sys.argv[1]
+ # Remove the subcommand from argv so the module's argparse works normally
+ sys.argv = [f"sysid {command}"] + sys.argv[2:]
+
+ if command == "capture":
+ from src.sysid.capture import main as cmd_main
+ elif command == "optimize":
+ from src.sysid.optimize import main as cmd_main
+ elif command == "visualize":
+ from src.sysid.visualize import main as cmd_main
+ elif command == "export":
+ from src.sysid.export import main as cmd_main
+ else:
+ print(f"Unknown command: {command}")
+ print("Available commands: capture, optimize, visualize, export")
+ sys.exit(1)
+
+ cmd_main()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000..84f5b1b
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,118 @@
+import os
+import pathlib
+import sys
+
+# Ensure project root is on sys.path so `src.*` imports work
+# regardless of which directory the script is invoked from.
+_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
+if _PROJECT_ROOT not in sys.path:
+ sys.path.insert(0, _PROJECT_ROOT)
+
+# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
+if sys.platform == "linux" and "DISPLAY" not in os.environ:
+ os.environ.setdefault("MUJOCO_GL", "osmesa")
+
+import hydra
+import hydra.utils as hydra_utils
+import structlog
+from clearml import Task
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf
+
+from src.core.env import BaseEnv
+from src.core.registry import build_env
+from src.core.runner import BaseRunner
+from src.training.trainer import Trainer, TrainerConfig
+
+logger = structlog.get_logger()
+
+
+# ── runner registry ───────────────────────────────────────────────────
+# Maps Hydra config-group name → (RunnerClass, ConfigClass)
+# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
+
+RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
+ "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
+ "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
+ "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
+ "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
+}
+
+
+def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
+ """Instantiate the right runner from the Hydra config-group name."""
+ if runner_name not in RUNNER_REGISTRY:
+ raise ValueError(
+ f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
+ )
+ module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
+
+ import importlib
+ mod = importlib.import_module(module_path)
+ runner_cls = getattr(mod, cls_name)
+ config_cls = getattr(mod, cfg_cls_name)
+
+ runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
+ return runner_cls(env=env, config=runner_config)
+
+
+def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
+ """Initialize ClearML task with project structure and tags."""
+ Task.ignore_requirements("torch")
+
+ env_name = choices.get("env", "cartpole")
+ runner_name = choices.get("runner", "mujoco")
+ training_name = choices.get("training", "ppo")
+
+ project = "RL-Framework"
+ task_name = f"{env_name}-{runner_name}-{training_name}"
+ tags = [env_name, runner_name, training_name]
+
+ task = Task.init(project_name=project, task_name=task_name, tags=tags)
+ task.set_base_docker(
+ "registry.kube.optimize/worker-image:latest",
+ docker_setup_bash_script=(
+ "apt-get update && apt-get install -y --no-install-recommends "
+ "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
+ "&& pip install 'jax[cuda12]' mujoco-mjx"
+ ),
+ )
+
+ req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
+ task.set_packages(str(req_file))
+
+ # Execute remotely if requested and running locally
+ if remote and task.running_locally():
+ logger.info("executing_task_remotely", queue="gpu-queue")
+ task.execute_remotely(queue_name="gpu-queue", exit_process=True)
+
+ return task
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+def main(cfg: DictConfig) -> None:
+ choices = HydraConfig.get().runtime.choices
+
+ # ClearML init — must happen before heavy work so remote execution
+ # can take over early. The remote worker re-runs the full script;
+ # execute_remotely() is a no-op on the worker side.
+ training_dict = OmegaConf.to_container(cfg.training, resolve=True)
+ remote = training_dict.pop("remote", False)
+ training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
+ task = _init_clearml(choices, remote=remote)
+
+ env_name = choices.get("env", "cartpole")
+ env = build_env(env_name, cfg)
+ runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
+ trainer_config = TrainerConfig(**training_dict)
+ trainer = Trainer(runner=runner, config=trainer_config)
+
+ try:
+ trainer.train()
+ finally:
+ trainer.close()
+ task.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/viz.py b/scripts/viz.py
new file mode 100644
index 0000000..c04adce
--- /dev/null
+++ b/scripts/viz.py
@@ -0,0 +1,254 @@
+"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
+
+Usage (simulation):
+ mjpython scripts/viz.py env=rotary_cartpole
+ mjpython scripts/viz.py env=cartpole +com=true
+
+Usage (real hardware — digital twin):
+ mjpython scripts/viz.py env=rotary_cartpole runner=serial
+ mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
+
+Controls:
+ Left/Right arrows — apply torque to first actuator
+ R — reset environment
+ Esc / close window — quit
+"""
+import math
+import sys
+import time
+from pathlib import Path
+
+# Ensure project root is on sys.path
+_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
+if _PROJECT_ROOT not in sys.path:
+ sys.path.insert(0, _PROJECT_ROOT)
+
+import hydra
+import mujoco
+import mujoco.viewer
+import numpy as np
+import structlog
+import torch
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig, OmegaConf
+
+from src.core.registry import build_env
+from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
+
+logger = structlog.get_logger()
+
+
+# ── keyboard state ───────────────────────────────────────────────────
+# Shared between the GLFW key callback and the main loop. Single-element
+# lists act as mutable cells so the callback can write values visible to
+# the loop without `global` statements.
+_action_val = [0.0]  # mutable container shared with callback
+_action_time = [0.0]  # timestamp of last key press
+_reset_flag = [False]
+_ACTION_HOLD_S = 0.25  # seconds the action stays active after last key event
+
+
+def _key_callback(keycode: int) -> None:
+ """Called by MuJoCo on key press & repeat (not release)."""
+ if keycode == 263: # GLFW_KEY_LEFT
+ _action_val[0] = -1.0
+ _action_time[0] = time.time()
+ elif keycode == 262: # GLFW_KEY_RIGHT
+ _action_val[0] = 1.0
+ _action_time[0] = time.time()
+ elif keycode == 82: # GLFW_KEY_R
+ _reset_flag[0] = True
+
+
+def _add_action_arrow(viewer, model, data, action_val: float) -> None:
+    """Draw an arrow on the motor joint showing applied torque direction.
+
+    Args:
+        viewer: mujoco viewer handle exposing `user_scn` (assumes a free slot
+            remains in user_scn.geoms — TODO confirm capacity is never hit).
+        model: mujoco.MjModel.
+        data: mujoco.MjData with forward kinematics already computed.
+        action_val: normalized action; sign selects color and direction.
+    """
+    # Nothing to draw for near-zero actions or actuator-less models.
+    if abs(action_val) < 0.01 or model.nu == 0:
+        return
+
+    # Get the body that the first actuator's joint belongs to
+    jnt_id = model.actuator_trnid[0, 0]
+    body_id = model.jnt_bodyid[jnt_id]
+
+    # Arrow origin: body position
+    pos = data.xpos[body_id].copy()
+    pos[2] += 0.02  # lift slightly above the body
+
+    # Arrow direction: along joint axis in world frame, scaled by action
+    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
+    arrow_len = 0.08 * action_val
+    direction = axis * np.sign(arrow_len)
+
+    # Build rotation matrix: arrow rendered along local z-axis
+    # (construct an orthonormal basis whose z column is the arrow direction)
+    z = direction / (np.linalg.norm(direction) + 1e-8)
+    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
+    x = np.cross(up, z)
+    x /= np.linalg.norm(x) + 1e-8
+    y = np.cross(z, x)
+    mat = np.column_stack([x, y, z]).flatten()
+
+    # Color: green = positive, red = negative
+    rgba = np.array(
+        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
+        dtype=np.float32,
+    )
+
+    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
+    mujoco.mjv_initGeom(
+        geom,
+        type=mujoco.mjtGeom.mjGEOM_ARROW,
+        size=np.array([0.008, 0.008, abs(arrow_len)]),
+        pos=pos,
+        mat=mat,
+        rgba=rgba,
+    )
+    viewer.user_scn.ngeom += 1
+
+
+@hydra.main(version_base=None, config_path="../configs", config_name="config")
+def main(cfg: DictConfig) -> None:
+ choices = HydraConfig.get().runtime.choices
+ env_name = choices.get("env", "cartpole")
+ runner_name = choices.get("runner", "mujoco")
+
+ if runner_name == "serial":
+ _main_serial(cfg, env_name)
+ else:
+ _main_sim(cfg, env_name)
+
+
+def _main_sim(cfg: DictConfig, env_name: str) -> None:
+    """Simulation visualization — step MuJoCo physics with keyboard control.
+
+    Args:
+        cfg: full Hydra config; `cfg.runner` is converted to MuJoCoRunnerConfig
+            kwargs, `cfg.get("com")` toggles CoM/inertia rendering.
+        env_name: env registry key passed to build_env.
+    """
+
+    # Build env + runner (single env for viz)
+    env = build_env(env_name, cfg)
+    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
+    runner_dict["num_envs"] = 1
+    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
+
+    # NOTE(review): reaches into runner private attributes (_model, _data)
+    # for direct viewer access — acceptable for a debug script.
+    model = runner._model
+    data = runner._data[0]
+
+    # Control period
+    dt_ctrl = runner.config.dt * runner.config.substeps
+
+    # Launch viewer
+    with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
+        # Show CoM / inertia if requested via Hydra override: viz.py +com=true
+        show_com = cfg.get("com", False)
+        if show_com:
+            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
+            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
+
+        obs, _ = runner.reset()
+        step = 0
+
+        logger.info("viewer_started", env=env_name,
+                    controls="Left/Right arrows = torque, R = reset")
+
+        while viewer.is_running():
+            # Read action from callback (expires after _ACTION_HOLD_S)
+            if time.time() - _action_time[0] < _ACTION_HOLD_S:
+                action_val = _action_val[0]
+            else:
+                action_val = 0.0
+
+            # Reset on R press
+            if _reset_flag[0]:
+                _reset_flag[0] = False
+                obs, _ = runner.reset()
+                step = 0
+                logger.info("reset")
+
+            # Step through runner
+            action = torch.tensor([[action_val]])
+            obs, reward, terminated, truncated, info = runner.step(action)
+
+            # Sync viewer with action arrow overlay
+            mujoco.mj_forward(model, data)
+            viewer.user_scn.ngeom = 0  # clear previous frame's overlays
+            _add_action_arrow(viewer, model, data, action_val)
+            viewer.sync()
+
+            # Print state (every 25 control steps)
+            if step % 25 == 0:
+                # assumes one qpos entry per joint (hinge/slide only) — TODO
+                # confirm for models with ball/free joints
+                joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
+                          for i in range(model.njnt)}
+                logger.debug("step", n=step, reward=round(reward.item(), 3),
+                             action=round(action_val, 1), **joints)
+
+            # Real-time pacing
+            time.sleep(dt_ctrl)
+            step += 1
+
+    runner.close()
+
+
+def _main_serial(cfg: DictConfig, env_name: str) -> None:
+    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.
+
+    The MuJoCo model is loaded for rendering only. Joint positions are
+    read from the ESP32 over serial and applied to the model each frame.
+    Keyboard arrows send motor commands to the real robot.
+
+    Args:
+        cfg: full Hydra config; `cfg.runner` supplies SerialRunnerConfig kwargs
+            (port, dt, ...).
+        env_name: env registry key passed to build_env.
+    """
+    from src.runners.serial import SerialRunner, SerialRunnerConfig
+
+    env = build_env(env_name, cfg)
+    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
+    serial_runner = SerialRunner(
+        env=env, config=SerialRunnerConfig(**runner_dict)
+    )
+
+    # Load MuJoCo model for visualisation (same URDF the sim uses).
+    # NOTE(review): uses SerialRunner private helpers (_ensure_viz_model,
+    # _send, _sync_viz, ...) — fine for a debug script, but keep in sync
+    # with the runner implementation.
+    serial_runner._ensure_viz_model()
+    model = serial_runner._viz_model
+    data = serial_runner._viz_data
+
+    with mujoco.viewer.launch_passive(
+        model, data, key_callback=_key_callback
+    ) as viewer:
+        # Show CoM / inertia if requested.
+        show_com = cfg.get("com", False)
+        if show_com:
+            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
+            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
+
+        logger.info(
+            "viewer_started",
+            env=env_name,
+            mode="serial (digital twin)",
+            port=serial_runner.config.port,
+            controls="Left/Right arrows = motor command, R = reset",
+        )
+
+        while viewer.is_running():
+            # Read action from keyboard callback.
+            if time.time() - _action_time[0] < _ACTION_HOLD_S:
+                action_val = _action_val[0]
+            else:
+                action_val = 0.0
+
+            # Reset on R press.
+            if _reset_flag[0]:
+                _reset_flag[0] = False
+                serial_runner._send("M0")
+                serial_runner._drive_to_center()
+                serial_runner._wait_for_pendulum_still()
+                logger.info("reset (drive-to-center + settle)")
+
+            # Send motor command to real hardware (PWM in [-255, 255]).
+            motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
+            serial_runner._send(f"M{motor_speed}")
+
+            # Sync MuJoCo model with real sensor data.
+            serial_runner._sync_viz()
+
+            # Render overlays and sync viewer.
+            viewer.user_scn.ngeom = 0
+            _add_action_arrow(viewer, model, data, action_val)
+            viewer.sync()
+
+            # Real-time pacing (~50 Hz, matches serial dt).
+            time.sleep(serial_runner.config.dt)
+
+    serial_runner.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/core/hardware.py b/src/core/hardware.py
new file mode 100644
index 0000000..2d49569
--- /dev/null
+++ b/src/core/hardware.py
@@ -0,0 +1,91 @@
+"""Real-hardware configuration — loaded from hardware.yaml next to robot.yaml.
+
+Provides robot-specific constants for the SerialRunner: encoder specs,
+safety limits, and reset behaviour. Simulation-only robots simply don't
+have a hardware.yaml (the loader returns None).
+
+Usage:
+ hw = load_hardware_config("assets/rotary_cartpole")
+ if hw is not None:
+ counts_per_rev = hw.encoder.ppr * hw.encoder.gear_ratio * 4.0
+"""
+
+import dataclasses
+from pathlib import Path
+
+import structlog
+import yaml
+
+log = structlog.get_logger()
+
+
+@dataclasses.dataclass
+class EncoderConfig:
+    """Rotary encoder parameters."""
+
+    ppr: int = 11  # pulses per revolution (before quadrature)
+    gear_ratio: float = 30.0  # gearbox ratio
+
+    @property
+    def counts_per_rev(self) -> float:
+        """Total encoder counts per output-shaft revolution (quadrature).
+
+        counts = ppr × gear_ratio × 4 (quadrature decoding yields 4 edges
+        per pulse); with defaults: 11 × 30 × 4 = 1320.
+        """
+        return self.ppr * self.gear_ratio * 4.0
+
+
+@dataclasses.dataclass
+class SafetyConfig:
+    """Safety limits enforced by the runner (not the env)."""
+
+    max_motor_angle_deg: float = 90.0  # hard termination (0 = disabled)
+    soft_limit_deg: float = 40.0  # progressive penalty ramp start
+
+
+@dataclasses.dataclass
+class ResetConfig:
+    """Parameters for the physical reset procedure.
+
+    Reset is two-phase: bang-bang drive-to-center of the motor arm,
+    then waiting for the free-hanging pendulum to settle.
+    """
+
+    drive_speed: int = 80  # PWM for bang-bang drive-to-center
+    deadband: int = 15  # encoder count threshold for "centered"
+    drive_timeout: float = 3.0  # seconds
+
+    settle_angle_deg: float = 2.0  # pendulum angle threshold (degrees)
+    settle_vel_dps: float = 5.0  # pendulum velocity threshold (deg/s)
+    settle_duration: float = 0.5  # seconds the pendulum must stay still
+    settle_timeout: float = 30.0  # give up after this many seconds
+
+
+@dataclasses.dataclass
+class HardwareConfig:
+    """Complete real-hardware description for a robot.
+
+    Aggregates the three sections of hardware.yaml; each falls back to
+    its dataclass defaults when absent.
+    """
+
+    encoder: EncoderConfig = dataclasses.field(default_factory=EncoderConfig)
+    safety: SafetyConfig = dataclasses.field(default_factory=SafetyConfig)
+    reset: ResetConfig = dataclasses.field(default_factory=ResetConfig)
+
+
+def load_hardware_config(robot_dir: str | Path) -> HardwareConfig | None:
+ """Load hardware.yaml from a directory.
+
+ Returns None if the file doesn't exist (simulation-only robot).
+ """
+ robot_dir = Path(robot_dir).resolve()
+ yaml_path = robot_dir / "hardware.yaml"
+
+ if not yaml_path.exists():
+ return None
+
+ raw = yaml.safe_load(yaml_path.read_text()) or {}
+
+ encoder = EncoderConfig(**raw.get("encoder", {}))
+ safety = SafetyConfig(**raw.get("safety", {}))
+ reset = ResetConfig(**raw.get("reset", {}))
+
+ config = HardwareConfig(encoder=encoder, safety=safety, reset=reset)
+
+ log.debug(
+ "hardware_config_loaded",
+ robot_dir=str(robot_dir),
+ counts_per_rev=encoder.counts_per_rev,
+ max_motor_angle_deg=safety.max_motor_angle_deg,
+ )
+ return config
diff --git a/src/envs/rotary_cartpole.py b/src/envs/rotary_cartpole.py
index 402e3ce..84eccf3 100644
--- a/src/envs/rotary_cartpole.py
+++ b/src/envs/rotary_cartpole.py
@@ -21,6 +21,7 @@ class RotaryCartPoleConfig(BaseEnvConfig):
"""
# Reward shaping
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
+ speed_penalty_scale: float = 0.1 # penalty for high pendulum velocity near top
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
@@ -69,11 +70,12 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
# Upright reward: -cos(θ) ∈ [-1, +1]
upright = -torch.cos(state.pendulum_angle)
- # Velocity penalties — make spinning expensive but allow swing-up
- pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
- motor_vel_penalty = 0.01 * state.motor_vel ** 2
+ # Penalise high pendulum velocity when near the top (upright).
+ # "nearness" is weighted by how upright the pendulum is (0 at bottom, 1 at top).
+ near_top = torch.clamp(-torch.cos(state.pendulum_angle), min=0.0) # 0‥1
+ speed_penalty = self.config.speed_penalty_scale * near_top * state.pendulum_vel.abs()
- return upright - pend_vel_penalty - motor_vel_penalty
+ return upright * self.config.reward_upright_scale - speed_penalty
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
# No early termination — episode runs for max_steps (truncation only).
diff --git a/src/hpo/__init__.py b/src/hpo/__init__.py
new file mode 100644
index 0000000..ff5ea99
--- /dev/null
+++ b/src/hpo/__init__.py
@@ -0,0 +1 @@
+"""Hyperparameter optimization — SMAC3 + ClearML Successive Halving."""
diff --git a/src/hpo/smac3.py b/src/hpo/smac3.py
new file mode 100644
index 0000000..b9bb10b
--- /dev/null
+++ b/src/hpo/smac3.py
@@ -0,0 +1,636 @@
+# Requires: pip install smac==2.0.0 ConfigSpace==0.4.20
+import contextlib
+import time
+from collections.abc import Sequence
+from functools import wraps
+from typing import Any
+
+from clearml import Task
+from clearml.automation.optimization import Objective, SearchStrategy
+from clearml.automation.parameters import Parameter
+from clearml.backend_interface.session import SendError
+from ConfigSpace import (
+ CategoricalHyperparameter,
+ ConfigurationSpace,
+ UniformFloatHyperparameter,
+ UniformIntegerHyperparameter,
+)
+from smac import MultiFidelityFacade
+from smac.intensifier.successive_halving import SuccessiveHalving
+from smac.runhistory.dataclasses import TrialInfo, TrialValue
+from smac.scenario import Scenario
+
+
+def retry_on_error(max_retries=5, initial_delay=2.0, backoff=2.0, exceptions=(Exception,)):
+    """Decorator to retry a function on exception with exponential backoff.
+
+    Args:
+        max_retries: total number of attempts before giving up.
+        initial_delay: seconds to sleep before the second attempt.
+        backoff: multiplier applied to the delay after each failed attempt.
+        exceptions: exception types that trigger a retry (others propagate).
+
+    Note:
+        On final failure the wrapper returns None instead of re-raising —
+        callers must handle a None result.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            delay = initial_delay
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except exceptions:
+                    if attempt == max_retries - 1:
+                        return None  # Return None instead of raising
+                    time.sleep(delay)
+                    delay *= backoff
+            return None
+
+        return wrapper
+
+    return decorator
+
+
+def _encode_param_name(name: str) -> str:
+ """Encode parameter name for ConfigSpace (replace / with __SLASH__)"""
+ return name.replace("/", "__SLASH__")
+
+
+def _decode_param_name(name: str) -> str:
+ """Decode parameter name back to original (replace __SLASH__ with /)"""
+ return name.replace("__SLASH__", "/")
+
+
+def _convert_param_to_cs(param: Parameter):
+ """
+ Convert a ClearML Parameter into a ConfigSpace hyperparameter,
+ adapted to ConfigSpace>=1.x (no more 'q' argument).
+ """
+ # Encode the name to avoid ConfigSpace issues with special chars like '/'
+ name = _encode_param_name(param.name)
+
+ # Categorical / discrete list
+ if hasattr(param, "values"):
+ return CategoricalHyperparameter(name=name, choices=list(param.values))
+
+ # Numeric range (float or int)
+ if hasattr(param, "min_value") and hasattr(param, "max_value"):
+ min_val = param.min_value
+ max_val = param.max_value
+
+ # Check if this should be treated as integer
+ if isinstance(min_val, int) and isinstance(max_val, int):
+ log = getattr(param, "log_scale", False)
+
+ # Check for step_size for quantization
+ if hasattr(param, "step_size"):
+ sv = int(param.step_size)
+ if sv != 1:
+ # emulate quantization by explicit list of values
+ choices = list(range(min_val, max_val + 1, sv))
+ return CategoricalHyperparameter(name=name, choices=choices)
+
+ # Simple uniform integer range
+ return UniformIntegerHyperparameter(name=name, lower=min_val, upper=max_val, log=log)
+ else:
+ # Treat as float
+ lower, upper = float(min_val), float(max_val)
+ log = getattr(param, "log_scale", False)
+ return UniformFloatHyperparameter(name=name, lower=lower, upper=upper, log=log)
+
+ raise ValueError(f"Unsupported Parameter type: {type(param)}")
+
+
+class OptimizerSMAC(SearchStrategy):
+ """
+ SMAC3-based hyperparameter optimizer, matching OptimizerBOHB interface.
+ """
+
+    def __init__(
+        self,
+        base_task_id: str,
+        hyper_parameters: Sequence[Parameter],
+        objective_metric: Objective,
+        execution_queue: str,
+        num_concurrent_workers: int,
+        min_iteration_per_job: int,
+        max_iteration_per_job: int,
+        total_max_jobs: int,
+        pool_period_min: float = 2.0,
+        time_limit_per_job: float | None = None,
+        compute_time_limit: float | None = None,
+        **smac_kwargs: Any,
+    ):
+        """Build the SMAC scenario, Successive Halving intensifier and bookkeeping.
+
+        Args:
+            base_task_id: ClearML task cloned for every trial.
+            hyper_parameters: ClearML Parameter definitions of the search space.
+            objective_metric: ClearML Objective (title/series + min/max sign).
+            execution_queue: ClearML queue trials are enqueued to.
+            num_concurrent_workers: max trials running at once.
+            min_iteration_per_job: minimum budget (epochs) per trial.
+            max_iteration_per_job: maximum budget (epochs) per trial.
+            total_max_jobs: total trial budget for the whole search.
+            pool_period_min: polling period in minutes (stored by base class).
+            time_limit_per_job: per-trial wallclock limit in minutes, if any.
+            compute_time_limit: overall wallclock limit in minutes, if any.
+            **smac_kwargs: forwarded to SuccessiveHalving (e.g. eta).
+        """
+        # Initialize base SearchStrategy
+        super().__init__(
+            base_task_id=base_task_id,
+            hyper_parameters=hyper_parameters,
+            objective_metric=objective_metric,
+            execution_queue=execution_queue,
+            num_concurrent_workers=num_concurrent_workers,
+            pool_period_min=pool_period_min,
+            time_limit_per_job=time_limit_per_job,
+            compute_time_limit=compute_time_limit,
+            min_iteration_per_job=min_iteration_per_job,
+            max_iteration_per_job=max_iteration_per_job,
+            total_max_jobs=total_max_jobs,
+        )
+
+        # Expose for internal use (access private attributes from base class)
+        self.execution_queue = self._execution_queue
+        self.min_iterations = min_iteration_per_job
+        self.max_iterations = max_iteration_per_job
+        self.num_concurrent_workers = self._num_concurrent_workers  # Fix: access private attribute
+
+        # Objective details
+        # Handle both single objective (string) and multi-objective (list) cases
+        if isinstance(self._objective_metric.title, list):
+            self.metric_title = self._objective_metric.title[0]  # Use first objective
+        else:
+            self.metric_title = self._objective_metric.title
+
+        if isinstance(self._objective_metric.series, list):
+            self.metric_series = self._objective_metric.series[0]  # Use first series
+        else:
+            self.metric_series = self._objective_metric.series
+
+        # ClearML Objective stores sign as a list, e.g., ['max'] or ['min']
+        objective_sign = getattr(self._objective_metric, "sign", None) or getattr(self._objective_metric, "order", None)
+
+        # Handle list case - extract first element
+        if isinstance(objective_sign, list):
+            objective_sign = objective_sign[0] if objective_sign else "max"
+
+        # Default to max if nothing found
+        if objective_sign is None:
+            objective_sign = "max"
+
+        self.maximize_metric = str(objective_sign).lower() in ("max", "max_global")
+
+        # Build ConfigSpace
+        self.config_space = ConfigurationSpace(seed=42)
+        for p in self._hyper_parameters:  # Access private attribute correctly
+            cs_hp = _convert_param_to_cs(p)
+            self.config_space.add(cs_hp)
+
+        # Configure SMAC Scenario
+        # NOTE(review): walltime_limit=None assumes Scenario accepts None to
+        # mean "no limit" — confirm against the installed SMAC version.
+        scenario = Scenario(
+            configspace=self.config_space,
+            n_trials=self.total_max_jobs,
+            min_budget=float(self.min_iterations),
+            max_budget=float(self.max_iterations),
+            walltime_limit=(self.compute_time_limit * 60) if self.compute_time_limit else None,
+            deterministic=True,
+        )
+
+        # build the Successive Halving intensifier (NOT Hyperband!)
+        # Hyperband runs multiple brackets with different starting budgets - wasteful
+        # Successive Halving: ALL configs start at min_budget, only best get promoted
+        # eta controls the reduction factor (default 3 means keep top 1/3 each round)
+        # eta can be overridden via smac_kwargs from HyperParameterOptimizer
+        eta = smac_kwargs.pop("eta", 3)  # Default to 3 if not specified
+        intensifier = SuccessiveHalving(scenario=scenario, eta=eta, **smac_kwargs)
+
+        # now pass that intensifier instance into the facade
+        # target_function is a dummy — trials run remotely; we only use ask/tell
+        self.smac = MultiFidelityFacade(
+            scenario=scenario,
+            target_function=lambda config, budget, seed: 0.0,
+            intensifier=intensifier,
+            overwrite=True,
+        )
+
+        # Bookkeeping
+        self.running_tasks = {}  # task_id -> trial info
+        self.task_start_times = {}  # task_id -> start time (for timeout)
+        self.completed_results = []
+        self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
+        self.time_limit_per_job = time_limit_per_job  # Store time limit (minutes)
+
+        # Checkpoint continuation tracking: config_key -> {budget: task_id}
+        # Used to find the previous task's checkpoint when promoting a config
+        self.config_to_tasks = {}  # config_key -> {budget: task_id}
+
+        # Manual Successive Halving control
+        self.eta = eta
+        self.current_budget = float(self.min_iterations)
+        self.configs_at_budget = {}  # budget -> list of (config, score, trial)
+        self.pending_configs = []  # configs waiting to be evaluated at current_budget - list of (trial, prev_task_id)
+        self.evaluated_at_budget = []  # (config, score, trial, task_id) for current budget
+        self.smac_asked_configs = set()  # track which configs SMAC has given us
+
+        # Calculate initial rung size for proper Successive Halving
+        # With eta=3: rung sizes are n, n/3, n/9, ...
+        # Total trials = n * (1 + 1/eta + 1/eta^2 + ...) = n * eta/(eta-1) for infinite series
+        # For finite rungs, calculate exactly
+        num_rungs = 1
+        b = float(self.min_iterations)
+        while b * eta <= self.max_iterations:
+            num_rungs += 1
+            b *= eta
+
+        # Sum of geometric series: 1 + 1/eta + 1/eta^2 + ... (num_rungs terms)
+        series_sum = sum(1.0 / (eta**i) for i in range(num_rungs))
+        self.initial_rung_size = int(self.total_max_jobs / series_sum)
+        self.initial_rung_size = max(self.initial_rung_size, self.num_concurrent_workers)  # at least num_workers
+        self.configs_needed_for_rung = self.initial_rung_size  # how many configs we still need for current rung
+        self.rung_closed = False  # whether we've collected all configs for current rung
+
+    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
+    def _get_task_safe(self, task_id: str):
+        """Safely get a task with retry logic.
+
+        Returns the ClearML Task, or None when all retries are exhausted
+        (retry_on_error returns None instead of raising).
+        """
+        return Task.get_task(task_id=task_id)
+
+    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
+    def _launch_task(self, config: dict, budget: float, prev_task_id: str | None = None):
+        """Launch a task with retry logic for robustness.
+
+        Args:
+            config: Hyperparameter configuration dict
+            budget: Number of epochs to train
+            prev_task_id: Optional task ID from previous budget to continue from (checkpoint)
+
+        Returns:
+            The enqueued clone Task, or None if the base task could not be
+            fetched (or the decorator exhausted its retries).
+        """
+        base = self._get_task_safe(task_id=self._base_task_id)
+        if base is None:
+            return None
+
+        clone = Task.clone(
+            source_task=base,
+            name=f"HPO Trial - {base.name}",
+            parent=Task.current_task().id,  # Set the current HPO task as parent
+        )
+        # Override hyperparameters
+        for k, v in config.items():
+            # Decode parameter name back to original (with slashes)
+            original_name = _decode_param_name(k)
+            # Convert numpy types to Python built-in types
+            if hasattr(v, "item"):  # numpy scalar
+                param_value = v.item()
+            elif isinstance(v, int | float | str | bool):
+                param_value = type(v)(v)  # Ensure it's the built-in type
+            else:
+                param_value = v
+            clone.set_parameter(original_name, param_value)
+        # Override epochs budget if multi-fidelity
+        if self.max_iterations != self.min_iterations:
+            clone.set_parameter("Hydra/training.max_epochs", int(budget))
+        else:
+            clone.set_parameter("Hydra/training.max_epochs", int(self.max_iterations))
+
+        # If we have a previous task, pass its ID so the worker can download the checkpoint
+        if prev_task_id:
+            clone.set_parameter("Hydra/training.resume_from_task_id", prev_task_id)
+
+        Task.enqueue(task=clone, queue_name=self.execution_queue)
+        # Track start time for timeout enforcement
+        self.task_start_times[clone.id] = time.time()
+        return clone
+
+ def start(self):
+ controller = Task.current_task()
+ total_launched = 0
+
+ # Keep launching & collecting until budget exhausted
+ while total_launched < self.total_max_jobs:
+ # Check if current budget rung is complete BEFORE asking for new trials
+ # (no running tasks, no pending configs, and we have results for this budget)
+ if not self.running_tasks and not self.pending_configs and self.evaluated_at_budget:
+ # Rung complete! Promote top performers to next budget
+
+ # Store results for this budget
+ self.configs_at_budget[self.current_budget] = self.evaluated_at_budget.copy()
+
+ # Sort by score (best first)
+ sorted_configs = sorted(
+ self.evaluated_at_budget,
+ key=lambda x: x[1], # score
+ reverse=self.maximize_metric,
+ )
+
+ # Print rung results
+ for _i, (_cfg, _score, _tri, _task_id) in enumerate(sorted_configs[:5], 1):
+ pass
+
+ # Move to next budget?
+ next_budget = self.current_budget * self.eta
+ if next_budget <= self.max_iterations:
+ # How many to promote (top 1/eta)
+ n_promote = max(1, len(sorted_configs) // self.eta)
+ promoted = sorted_configs[:n_promote]
+
+ # Update budget and reset for next rung
+ self.current_budget = next_budget
+ self.evaluated_at_budget = []
+ self.configs_needed_for_rung = 0 # promoted configs are all we need
+ self.rung_closed = True # rung is pre-filled with promoted configs
+
+ # Re-queue promoted configs with new budget
+ # Include the previous task ID for checkpoint continuation
+ for _cfg, _score, old_trial, prev_task_id in promoted:
+ new_trial = TrialInfo(
+ config=old_trial.config,
+ instance=old_trial.instance,
+ seed=old_trial.seed,
+ budget=self.current_budget,
+ )
+ # Store as tuple: (trial, prev_task_id)
+ self.pending_configs.append((new_trial, prev_task_id))
+ else:
+ # All budgets complete
+ break
+
+ # Fill pending_configs with new trials ONLY if we haven't closed this rung yet
+ # For the first rung: ask SMAC for initial_rung_size configs total
+ # For subsequent rungs: only use promoted configs (rung is already closed)
+ while (
+ not self.rung_closed
+ and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
+ < self.initial_rung_size
+ and total_launched < self.total_max_jobs
+ ):
+ trial = self.smac.ask()
+ if trial is None:
+ self.rung_closed = True
+ break
+ # Create new trial with forced budget (TrialInfo is frozen, can't modify)
+ trial_with_budget = TrialInfo(
+ config=trial.config,
+ instance=trial.instance,
+ seed=trial.seed,
+ budget=self.current_budget,
+ )
+ cfg_key = str(sorted(trial.config.items()))
+ if cfg_key not in self.smac_asked_configs:
+ self.smac_asked_configs.add(cfg_key)
+ # Store as tuple: (trial, None) - no previous task for new configs
+ self.pending_configs.append((trial_with_budget, None))
+
+ # Check if we've collected enough configs for this rung
+ if (
+ not self.rung_closed
+ and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
+ >= self.initial_rung_size
+ ):
+ self.rung_closed = True
+
+ # Launch pending configs up to concurrent limit
+ while self.pending_configs and len(self.running_tasks) < self.num_concurrent_workers:
+ # Unpack tuple: (trial, prev_task_id)
+ trial, prev_task_id = self.pending_configs.pop(0)
+ t = self._launch_task(trial.config, self.current_budget, prev_task_id=prev_task_id)
+ if t is None:
+ # Launch failed, mark trial as failed and continue
+ # Tell SMAC this trial failed with worst possible score
+ cost = float("inf") if self.maximize_metric else float("-inf")
+ self.smac.tell(trial, TrialValue(cost=cost))
+ total_launched += 1
+ continue
+ self.running_tasks[t.id] = trial
+
+ # Track which task ID was used for this config at this budget
+ cfg_key = str(sorted(trial.config.items()))
+ if cfg_key not in self.config_to_tasks:
+ self.config_to_tasks[cfg_key] = {}
+ self.config_to_tasks[cfg_key][self.current_budget] = t.id
+
+ total_launched += 1
+
+ if not self.running_tasks and not self.pending_configs:
+ break
+
+ # Poll for finished or timed out
+ done = []
+ timed_out = []
+ failed_to_check = []
+ for tid, _tri in self.running_tasks.items():
+ try:
+ task = self._get_task_safe(task_id=tid)
+ if task is None:
+ failed_to_check.append(tid)
+ continue
+
+ st = task.get_status()
+
+ # Check if task completed normally
+ if st == Task.TaskStatusEnum.completed or st in (
+ Task.TaskStatusEnum.failed,
+ Task.TaskStatusEnum.stopped,
+ ):
+ done.append(tid)
+ # Check for timeout (if time limit is set)
+ elif self.time_limit_per_job and tid in self.task_start_times:
+ elapsed_minutes = (time.time() - self.task_start_times[tid]) / 60.0
+ if elapsed_minutes > self.time_limit_per_job:
+ with contextlib.suppress(Exception):
+ task.mark_stopped(force=True)
+ timed_out.append(tid)
+ except Exception:
+ # Don't mark as failed immediately, might be transient
+ # Only mark failed after multiple consecutive failures
+ if not hasattr(self, "_task_check_failures"):
+ self._task_check_failures = {}
+ self._task_check_failures[tid] = self._task_check_failures.get(tid, 0) + 1
+ if self._task_check_failures[tid] >= 5: # 5 consecutive failures
+ failed_to_check.append(tid)
+ del self._task_check_failures[tid]
+
+ # Process tasks that failed to check
+ for tid in failed_to_check:
+ tri = self.running_tasks.pop(tid)
+ if tid in self.task_start_times:
+ del self.task_start_times[tid]
+ # Tell SMAC this trial failed with worst possible score
+ res = float("-inf") if self.maximize_metric else float("inf")
+ cost = -res if self.maximize_metric else res
+ self.smac.tell(tri, TrialValue(cost=cost))
+ self.completed_results.append(
+ {
+ "task_id": tid,
+ "config": tri.config,
+ "budget": tri.budget,
+ "value": res,
+ "failed": True,
+ }
+ )
+ # Store result with task_id for checkpoint tracking
+ self.evaluated_at_budget.append((tri.config, res, tri, tid))
+
+ # Process completed tasks
+ for tid in done:
+ tri = self.running_tasks.pop(tid)
+ if tid in self.task_start_times:
+ del self.task_start_times[tid]
+
+ # Clear any accumulated failures for this task
+ if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
+ del self._task_check_failures[tid]
+
+ task = self._get_task_safe(task_id=tid)
+ if task is None:
+ res = float("-inf") if self.maximize_metric else float("inf")
+ else:
+ res = self._get_objective(task)
+
+ if res is None or res == float("-inf") or res == float("inf"):
+ res = float("-inf") if self.maximize_metric else float("inf")
+
+ cost = -res if self.maximize_metric else res
+ self.smac.tell(tri, TrialValue(cost=cost))
+ self.completed_results.append(
+ {
+ "task_id": tid,
+ "config": tri.config,
+ "budget": tri.budget,
+ "value": res,
+ }
+ )
+
+ # Store result for this budget rung with task_id for checkpoint tracking
+ self.evaluated_at_budget.append((tri.config, res, tri, tid))
+
+ iteration = len(self.completed_results)
+
+ # Always report the trial score (even if it's bad)
+ if res is not None and res != float("-inf") and res != float("inf"):
+ controller.get_logger().report_scalar(
+ title="Optimization", series="trial_score", value=res, iteration=iteration
+ )
+ controller.get_logger().report_scalar(
+ title="Optimization",
+ series="trial_budget",
+ value=tri.budget or self.max_iterations,
+ iteration=iteration,
+ )
+
+ # Update best score tracking based on actual results
+ if res is not None and res != float("-inf") and res != float("inf"):
+ if self.maximize_metric:
+ self.best_score_so_far = max(self.best_score_so_far, res)
+ elif res < self.best_score_so_far:
+ self.best_score_so_far = res
+
+ # Always report best score so far (shows flat line when no improvement)
+ if self.best_score_so_far != float("-inf") and self.best_score_so_far != float("inf"):
+ controller.get_logger().report_scalar(
+ title="Optimization", series="best_score", value=self.best_score_so_far, iteration=iteration
+ )
+
+ # Report running statistics
+ valid_scores = [
+ r["value"]
+ for r in self.completed_results
+ if r["value"] is not None and r["value"] != float("-inf") and r["value"] != float("inf")
+ ]
+ if valid_scores:
+ controller.get_logger().report_scalar(
+ title="Optimization",
+ series="mean_score",
+ value=sum(valid_scores) / len(valid_scores),
+ iteration=iteration,
+ )
+ controller.get_logger().report_scalar(
+ title="Progress",
+ series="completed_trials",
+ value=len(self.completed_results),
+ iteration=iteration,
+ )
+ controller.get_logger().report_scalar(
+ title="Progress", series="running_tasks", value=len(self.running_tasks), iteration=iteration
+ )
+
+ # Process timed out tasks (treat as failed with current objective value)
+ for tid in timed_out:
+ tri = self.running_tasks.pop(tid)
+ if tid in self.task_start_times:
+ del self.task_start_times[tid]
+
+ # Clear any accumulated failures for this task
+ if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
+ del self._task_check_failures[tid]
+
+ # Try to get the last objective value before timeout
+ task = self._get_task_safe(task_id=tid)
+ if task is None:
+ res = float("-inf") if self.maximize_metric else float("inf")
+ else:
+ res = self._get_objective(task)
+
+ if res is None:
+ res = float("-inf") if self.maximize_metric else float("inf")
+ cost = -res if self.maximize_metric else res
+ self.smac.tell(tri, TrialValue(cost=cost))
+ self.completed_results.append(
+ {
+ "task_id": tid,
+ "config": tri.config,
+ "budget": tri.budget,
+ "value": res,
+ "timed_out": True,
+ }
+ )
+
+ # Store timed out result for this budget rung with task_id
+ self.evaluated_at_budget.append((tri.config, res, tri, tid))
+
+ time.sleep(self.pool_period_minutes * 60) # Fix: use correct attribute name from base class
+ if self.compute_time_limit and controller.get_runtime() > self.compute_time_limit * 60:
+ break
+
+ # Finalize
+ self._finalize()
+ return self.completed_results
+
+ @retry_on_error(max_retries=3, initial_delay=2.0, exceptions=(SendError, ConnectionError, KeyError))
+ def _get_objective(self, task: Task):
+ """Get objective metric value with retry logic for robustness."""
+ if task is None:
+ return None
+
+ try:
+ m = task.get_last_scalar_metrics()
+ if not m:
+ return None
+
+ metric_data = m[self.metric_title][self.metric_series]
+
+ # ClearML returns dict with 'last', 'min', 'max' keys representing
+ # the last/min/max values of this series over ALL logged iterations.
+ # For snake_length/train_max: 'last' is the last logged train_max value,
+ # 'max' is the highest train_max ever logged during training.
+
+ # Use 'max' if maximizing (we want the best performance achieved),
+ # 'min' if minimizing, fallback to 'last'
+ if self.maximize_metric and "max" in metric_data:
+ result = metric_data["max"]
+ elif not self.maximize_metric and "min" in metric_data:
+ result = metric_data["min"]
+ else:
+ result = metric_data["last"]
+ return result
+ except (KeyError, Exception):
+ return None
+
+    def _finalize(self):
+        """Report the final best score and upload the best config artifact.
+
+        Best-effort: any failure while querying SMAC's incumbent is logged
+        and the artifact falls back to our own tracked best score.
+        """
+        controller = Task.current_task()
+        # Report final best score
+        controller.get_logger().report_text(f"Final best score: {self.best_score_so_far}")
+
+        # Also try to get SMAC's incumbent for comparison
+        try:
+            incumbent = self.smac.intensifier.get_incumbent()
+            if incumbent is not None:
+                runhistory = self.smac.runhistory
+                # Try different ways to get the cost
+                incumbent_cost = None
+                try:
+                    incumbent_cost = runhistory.get_cost(incumbent)
+                except Exception:
+                    # Fallback: search through runhistory manually
+                    for trial_key, trial_value in runhistory.items():
+                        trial_config = runhistory.get_config(trial_key.config_id)
+                        if trial_config == incumbent and (incumbent_cost is None or trial_value.cost < incumbent_cost):
+                            incumbent_cost = trial_value.cost
+
+                if incumbent_cost is not None:
+                    # Costs were fed to SMAC sign-flipped when maximizing; undo that here.
+                    score = -incumbent_cost if self.maximize_metric else incumbent_cost
+                    controller.get_logger().report_text(f"SMAC incumbent: {incumbent}, score: {score}")
+                    controller.upload_artifact(
+                        "best_config",
+                        {"config": dict(incumbent), "score": score, "our_best_score": self.best_score_so_far},
+                    )
+                else:
+                    controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
+        except Exception as e:
+            controller.get_logger().report_text(f"Error getting SMAC incumbent: {e}")
+            controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
diff --git a/src/runners/mjx.py b/src/runners/mjx.py
index bf2dcce..a92ac2d 100644
--- a/src/runners/mjx.py
+++ b/src/runners/mjx.py
@@ -214,6 +214,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
"""Offscreen render — copies one env's state from GPU to CPU."""
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
+ self._render_data.ctrl[:] = np.asarray(self._batch_data.ctrl[env_idx])
mujoco.mj_forward(self._mj_model, self._render_data)
if not hasattr(self, "_offscreen_renderer"):
@@ -221,4 +222,10 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._mj_model, width=640, height=480,
)
self._offscreen_renderer.update_scene(self._render_data)
- return self._offscreen_renderer.render()
+ frame = self._offscreen_renderer.render().copy()
+
+ # Import shared overlay helper from mujoco runner
+ from src.runners.mujoco import _draw_action_overlay
+ ctrl_val = float(self._render_data.ctrl[0]) if self._mj_model.nu > 0 else 0.0
+ _draw_action_overlay(frame, ctrl_val)
+ return frame
diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py
index bc395c0..73f071b 100644
--- a/src/runners/mujoco.py
+++ b/src/runners/mujoco.py
@@ -283,4 +283,43 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
)
mujoco.mj_forward(self._model, self._data[env_idx])
self._offscreen_renderer.update_scene(self._data[env_idx])
- return self._offscreen_renderer.render()
\ No newline at end of file
+ frame = self._offscreen_renderer.render().copy()
+
+ # Draw action bar overlay — shows ctrl[0] as a horizontal bar
+ ctrl_val = float(self._data[env_idx].ctrl[0]) if self._model.nu > 0 else 0.0
+ _draw_action_overlay(frame, ctrl_val)
+ return frame
+
+
+def _draw_action_overlay(frame: np.ndarray, action: float) -> None:
+ """Draw an action bar + text on a rendered frame (no OpenCV needed).
+
+ Bar is centered horizontally: green to the right (+), red to the left (-).
+ """
+ h, w = frame.shape[:2]
+
+ # Bar geometry
+ bar_y = h - 30
+ bar_h = 16
+ bar_x_center = w // 2
+ bar_half_w = w // 4 # max half-width of the bar
+ bar_x_left = bar_x_center - bar_half_w
+ bar_x_right = bar_x_center + bar_half_w
+
+ # Background (dark grey)
+ frame[bar_y:bar_y + bar_h, bar_x_left:bar_x_right] = [40, 40, 40]
+
+ # Filled bar
+ fill_len = int(abs(action) * bar_half_w)
+ if action > 0:
+ color = [60, 200, 60] # green
+ x0 = bar_x_center
+ x1 = min(bar_x_center + fill_len, bar_x_right)
+ else:
+ color = [200, 60, 60] # red
+ x1 = bar_x_center
+ x0 = max(bar_x_center - fill_len, bar_x_left)
+ frame[bar_y:bar_y + bar_h, x0:x1] = color
+
+ # Center tick mark (white)
+ frame[bar_y:bar_y + bar_h, bar_x_center - 1:bar_x_center + 1] = [255, 255, 255]
\ No newline at end of file
diff --git a/src/runners/serial.py b/src/runners/serial.py
new file mode 100644
index 0000000..1c38725
--- /dev/null
+++ b/src/runners/serial.py
@@ -0,0 +1,571 @@
+"""Serial runner — real hardware over USB/serial (ESP32).
+
+Implements the BaseRunner interface for a single physical robot.
+All physics come from the real world; the runner translates between
+the ESP32 serial protocol and the qpos/qvel tensors that BaseRunner
+and BaseEnv expect.
+
+Serial protocol (ESP32 firmware):
+ Commands sent TO the ESP32:
+ G — start streaming state lines
+ H — stop streaming
+ M — set motor PWM speed (-255 … 255)
+
+ State lines received FROM the ESP32:
+        S,<timestamp_ms>,<encoder_count>,<rpm>,<motor_speed>,<at_limit>,
+        <pendulum_angle>,<pendulum_velocity>,<target_speed>,<braking>,
+        <enc_vel_cps>,<pendulum_ok>
+        (11 comma-separated fields after the ``S`` prefix)
+
+A daemon thread continuously reads the serial stream so the control
+loop never blocks on I/O.
+
+Usage:
+ python train.py env=rotary_cartpole runner=serial training=ppo_real
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import logging
+import math
+import threading
+import time
+from typing import Any
+
+import numpy as np
+import torch
+
+from src.core.env import BaseEnv
+from src.core.hardware import HardwareConfig, load_hardware_config
+from src.core.runner import BaseRunner, BaseRunnerConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class SerialRunnerConfig(BaseRunnerConfig):
+ """Configuration for serial communication with the ESP32."""
+
+ num_envs: int = 1 # always 1 — single physical robot
+ device: str = "cpu"
+
+ port: str = "/dev/cu.usbserial-0001"
+ baud: int = 115200
+ dt: float = 0.02 # control loop period (seconds), 50 Hz
+ no_data_timeout: float = 2.0 # seconds of silence → disconnect
+ encoder_jump_threshold: int = 200 # encoder tick jump → reboot
+
+
+class SerialRunner(BaseRunner[SerialRunnerConfig]):
+ """BaseRunner implementation that talks to real hardware over serial.
+
+ Maps the ESP32 serial protocol to qpos/qvel tensors so the existing
+ RotaryCartPoleEnv (or any compatible env) works unchanged.
+
+ qpos layout: [motor_angle_rad, pendulum_angle_rad]
+ qvel layout: [motor_vel_rad_s, pendulum_vel_rad_s]
+ """
+
+ # ------------------------------------------------------------------
+ # BaseRunner interface
+ # ------------------------------------------------------------------
+
+ @property
+ def num_envs(self) -> int:
+ return 1
+
+ @property
+ def device(self) -> torch.device:
+ return torch.device("cpu")
+
+ def _sim_initialize(self, config: SerialRunnerConfig) -> None:
+ # Load hardware description (encoder, safety, reset params).
+ hw = load_hardware_config(self.env.config.robot_path)
+ if hw is None:
+ raise FileNotFoundError(
+ f"hardware.yaml not found in {self.env.config.robot_path}. "
+ "The serial runner requires a hardware config for encoder, "
+ "safety, and reset parameters."
+ )
+ self._hw: HardwareConfig = hw
+ self._counts_per_rev: float = hw.encoder.counts_per_rev
+ self._max_motor_angle_rad: float = (
+ math.radians(hw.safety.max_motor_angle_deg)
+ if hw.safety.max_motor_angle_deg > 0
+ else 0.0
+ )
+
+ # Joint dimensions for the rotary cartpole (motor + pendulum).
+ self._nq = 2
+ self._nv = 2
+
+ # Import serial here so it's not a hard dependency for sim-only users.
+ import serial as _serial
+
+ self._serial_mod = _serial
+
+ self.ser: _serial.Serial = _serial.Serial(
+ config.port, config.baud, timeout=0.05
+ )
+ time.sleep(2) # Wait for ESP32 boot.
+ self.ser.reset_input_buffer()
+
+ # Internal state tracking.
+ self._rebooted: bool = False
+ self._serial_disconnected: bool = False
+ self._last_esp_ms: int = 0
+ self._last_data_time: float = time.monotonic()
+ self._last_encoder_count: int = 0
+ self._streaming: bool = False
+
+ # Latest parsed state (updated by the reader thread).
+ self._latest_state: dict[str, Any] = {
+ "timestamp_ms": 0,
+ "encoder_count": 0,
+ "rpm": 0.0,
+ "motor_speed": 0,
+ "at_limit": False,
+ "pendulum_angle": 0.0,
+ "pendulum_velocity": 0.0,
+ "target_speed": 0,
+ "braking": False,
+ "enc_vel_cps": 0.0,
+ "pendulum_ok": False,
+ }
+ self._state_lock = threading.Lock()
+ self._state_event = threading.Event()
+
+ # Start background serial reader.
+ self._reader_running = True
+ self._reader_thread = threading.Thread(
+ target=self._serial_reader, daemon=True
+ )
+ self._reader_thread.start()
+
+ # Start streaming.
+ self._send("G")
+ self._streaming = True
+ self._last_data_time = time.monotonic()
+
+ # Track wall-clock time of last step for PPO-gap detection.
+ self._last_step_time: float = time.monotonic()
+
+ def _sim_step(
+ self, actions: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ now = time.monotonic()
+
+ # Detect PPO update gap: if more than 0.5s since last step,
+ # the optimizer was running and no motor commands were sent.
+ # Trigger a full reset so the robot starts from a clean state.
+ gap = now - self._last_step_time
+ if gap > 0.5:
+ logger.info(
+ "PPO update gap detected (%.1f s) — resetting before resuming.",
+ gap,
+ )
+ self._send("M0")
+ all_ids = torch.arange(self.num_envs, device=self.device)
+ self._sim_reset(all_ids)
+ self.step_counts.zero_()
+
+ step_start = time.monotonic()
+
+ # Map normalised action [-1, 1] → PWM [-255, 255].
+ action_val = float(actions[0, 0].clamp(-1.0, 1.0))
+ motor_speed = int(action_val * 255)
+ self._send(f"M{motor_speed}")
+
+ # Enforce dt wall-clock timing.
+ elapsed = time.monotonic() - step_start
+ remaining = self.config.dt - elapsed
+ if remaining > 0:
+ time.sleep(remaining)
+
+ # Read latest sensor data (non-blocking — dt sleep ensures freshness).
+ state = self._read_state()
+
+ motor_angle = (
+ state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ motor_vel = (
+ state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ pendulum_angle = math.radians(state["pendulum_angle"])
+ pendulum_vel = math.radians(state["pendulum_velocity"])
+
+ # Cache motor angle for safety check in step() — avoids a second read.
+ self._last_motor_angle_rad = motor_angle
+ self._last_step_time = time.monotonic()
+
+ qpos = torch.tensor(
+ [[motor_angle, pendulum_angle]], dtype=torch.float32
+ )
+ qvel = torch.tensor(
+ [[motor_vel, pendulum_vel]], dtype=torch.float32
+ )
+ return qpos, qvel
+
+ def _sim_reset(
+ self, env_ids: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # If ESP32 rebooted or disconnected, we can't recover.
+ if self._rebooted or self._serial_disconnected:
+ raise RuntimeError(
+ "ESP32 rebooted or disconnected during training! "
+ "Encoder center is lost. "
+ "Please re-center the motor manually and restart."
+ )
+
+ # Stop motor and restart streaming.
+ self._send("M0")
+ self._send("H")
+ self._streaming = False
+ time.sleep(0.05)
+ self._state_event.clear()
+ self._send("G")
+ self._streaming = True
+ self._last_data_time = time.monotonic()
+ time.sleep(0.05)
+
+ # Physically return the motor to the centre position.
+ self._drive_to_center()
+
+ # Wait until the pendulum settles.
+ self._wait_for_pendulum_still()
+
+ # Refresh data timer so health checks don't false-positive.
+ self._last_data_time = time.monotonic()
+
+ # Read settled state and return as qpos/qvel.
+ state = self._read_state_blocking()
+ motor_angle = (
+ state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ motor_vel = (
+ state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ pendulum_angle = math.radians(state["pendulum_angle"])
+ pendulum_vel = math.radians(state["pendulum_velocity"])
+
+ qpos = torch.tensor(
+ [[motor_angle, pendulum_angle]], dtype=torch.float32
+ )
+ qvel = torch.tensor(
+ [[motor_vel, pendulum_vel]], dtype=torch.float32
+ )
+ return qpos, qvel
+
+ def _sim_close(self) -> None:
+ self._reader_running = False
+ self._streaming = False
+ self._send("H") # Stop streaming.
+ self._send("M0") # Stop motor.
+ time.sleep(0.1)
+ self._reader_thread.join(timeout=1.0)
+ self.ser.close()
+ if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
+ self._offscreen_renderer.close()
+
+ # ------------------------------------------------------------------
+ # MuJoCo digital-twin rendering
+ # ------------------------------------------------------------------
+
+ def _ensure_viz_model(self) -> None:
+ """Lazily load the MuJoCo model for visualisation (digital twin).
+
+ Reuses the same URDF + robot.yaml that the MuJoCoRunner would use,
+ but only for rendering — no physics stepping.
+ """
+ if hasattr(self, "_viz_model"):
+ return
+
+ import mujoco
+ from src.runners.mujoco import MuJoCoRunner
+
+ self._viz_model = MuJoCoRunner._load_model(self.env.robot)
+ self._viz_data = mujoco.MjData(self._viz_model)
+ self._offscreen_renderer = None
+
+ def _sync_viz(self) -> None:
+ """Copy current serial sensor state into the MuJoCo viz model."""
+ import mujoco
+
+ self._ensure_viz_model()
+ state = self._read_state()
+
+ # Set joint positions from serial data.
+ motor_angle = (
+ state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ pendulum_angle = math.radians(state["pendulum_angle"])
+ self._viz_data.qpos[0] = motor_angle
+ self._viz_data.qpos[1] = pendulum_angle
+
+ # Set joint velocities (for any velocity-dependent visuals).
+ motor_vel = (
+ state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ pendulum_vel = math.radians(state["pendulum_velocity"])
+ self._viz_data.qvel[0] = motor_vel
+ self._viz_data.qvel[1] = pendulum_vel
+
+ # Forward kinematics (updates body positions for rendering).
+ mujoco.mj_forward(self._viz_model, self._viz_data)
+
+ def render(self, env_idx: int = 0) -> np.ndarray:
+ """Offscreen render of the digital-twin MuJoCo model.
+
+ Called by VideoRecordingTrainer during training to capture frames.
+ """
+ import mujoco
+
+ self._sync_viz()
+
+ if self._offscreen_renderer is None:
+ self._offscreen_renderer = mujoco.Renderer(
+ self._viz_model, width=640, height=480,
+ )
+ self._offscreen_renderer.update_scene(self._viz_data)
+ return self._offscreen_renderer.render().copy()
+
+ # ------------------------------------------------------------------
+ # Override step() for runner-level safety
+ # ------------------------------------------------------------------
+
+ def step(
+ self, actions: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
+ # Check for ESP32 reboot / disconnect BEFORE stepping.
+ if self._rebooted or self._serial_disconnected:
+ self._send("M0")
+ # Return a terminal observation with penalty.
+ qpos, qvel = self._make_current_state()
+ state = self.env.build_state(qpos, qvel)
+ obs = self.env.compute_observations(state)
+ reward = torch.tensor([[-100.0]])
+ terminated = torch.tensor([[True]])
+ truncated = torch.tensor([[False]])
+ return obs, reward, terminated, truncated, {"reboot_detected": True}
+
+ # Normal step via BaseRunner (calls _sim_step → env logic).
+ obs, rewards, terminated, truncated, info = super().step(actions)
+
+ # Check connection health after stepping.
+ if not self._check_connection_health():
+ self._send("M0")
+ terminated = torch.tensor([[True]])
+ rewards = torch.tensor([[-100.0]])
+ info["reboot_detected"] = True
+
+ # Check motor angle against hard safety limit.
+ # Uses the cached value from _sim_step — no extra serial read.
+ if self._max_motor_angle_rad > 0:
+ motor_angle = abs(getattr(self, "_last_motor_angle_rad", 0.0))
+ if motor_angle >= self._max_motor_angle_rad:
+ self._send("M0")
+ terminated = torch.tensor([[True]])
+ rewards = torch.tensor([[-100.0]])
+ info["motor_limit_exceeded"] = True
+
+ # Always stop motor on episode end.
+ if terminated.any() or truncated.any():
+ self._send("M0")
+
+ return obs, rewards, terminated, truncated, info
+
+ # ------------------------------------------------------------------
+ # Serial helpers
+ # ------------------------------------------------------------------
+
+ def _send(self, cmd: str) -> None:
+ """Send a command to the ESP32."""
+ try:
+ self.ser.write(f"{cmd}\n".encode())
+ except (OSError, self._serial_mod.SerialException):
+ self._serial_disconnected = True
+
+ def _serial_reader(self) -> None:
+ """Background thread: continuously read and parse serial lines."""
+ while self._reader_running:
+ try:
+ if self.ser.in_waiting:
+ line = (
+ self.ser.readline()
+ .decode("utf-8", errors="ignore")
+ .strip()
+ )
+
+ # Detect ESP32 reboot: it prints READY on startup.
+ if line.startswith("READY"):
+ self._rebooted = True
+ logger.critical("ESP32 reboot detected: %s", line)
+ continue
+
+ if line.startswith("S,"):
+ parts = line.split(",")
+ if len(parts) >= 12:
+ esp_ms = int(parts[1])
+ enc = int(parts[2])
+
+ # Detect reboot: timestamp jumped backwards.
+ if (
+ self._last_esp_ms > 5000
+ and esp_ms < self._last_esp_ms - 3000
+ ):
+ self._rebooted = True
+ logger.critical(
+ "ESP32 reboot detected: timestamp"
+ " %d -> %d",
+ self._last_esp_ms,
+ esp_ms,
+ )
+
+ # Detect reboot: encoder snapped to 0 from
+ # a far position.
+ if (
+ abs(self._last_encoder_count)
+ > self.config.encoder_jump_threshold
+ and abs(enc) < 5
+ ):
+ self._rebooted = True
+ logger.critical(
+ "ESP32 reboot detected: encoder"
+ " jumped %d -> %d",
+ self._last_encoder_count,
+ enc,
+ )
+
+ self._last_esp_ms = esp_ms
+ self._last_encoder_count = enc
+ self._last_data_time = time.monotonic()
+
+ parsed: dict[str, Any] = {
+ "timestamp_ms": esp_ms,
+ "encoder_count": enc,
+ "rpm": float(parts[3]),
+ "motor_speed": int(parts[4]),
+ "at_limit": bool(int(parts[5])),
+ "pendulum_angle": float(parts[6]),
+ "pendulum_velocity": float(parts[7]),
+ "target_speed": int(parts[8]),
+ "braking": bool(int(parts[9])),
+ "enc_vel_cps": float(parts[10]),
+ "pendulum_ok": bool(int(parts[11])),
+ }
+ with self._state_lock:
+ self._latest_state = parsed
+ self._state_event.set()
+ else:
+ time.sleep(0.001) # Avoid busy-spinning.
+ except (OSError, self._serial_mod.SerialException) as exc:
+ self._serial_disconnected = True
+ logger.critical("Serial connection lost: %s", exc)
+ break
+
+ def _check_connection_health(self) -> bool:
+ """Return True if the ESP32 connection appears healthy."""
+ if self._serial_disconnected:
+ logger.critical("ESP32 serial connection lost.")
+ return False
+ if (
+ self._streaming
+ and (time.monotonic() - self._last_data_time)
+ > self.config.no_data_timeout
+ ):
+ logger.critical(
+ "No data from ESP32 for %.1f s — possible crash/disconnect.",
+ time.monotonic() - self._last_data_time,
+ )
+ self._rebooted = True
+ return False
+ return True
+
+ def _read_state(self) -> dict[str, Any]:
+ """Return the most recent state from the reader thread (non-blocking).
+
+ The background thread updates at ~50 Hz and `_sim_step` already
+ sleeps for `dt` before calling this, so the data is always fresh.
+ """
+ with self._state_lock:
+ return dict(self._latest_state)
+
+ def _read_state_blocking(self, timeout: float = 0.05) -> dict[str, Any]:
+ """Wait for a fresh sample, then return it.
+
+ Used during reset / settling where we need to guarantee we have
+ a new reading (no prior dt sleep).
+ """
+ self._state_event.clear()
+ self._state_event.wait(timeout=timeout)
+ with self._state_lock:
+ return dict(self._latest_state)
+
+ def _make_current_state(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """Build qpos/qvel from current sensor data (utility)."""
+ state = self._read_state_blocking()
+ motor_angle = (
+ state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ motor_vel = (
+ state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
+ )
+ pendulum_angle = math.radians(state["pendulum_angle"])
+ pendulum_vel = math.radians(state["pendulum_velocity"])
+
+ qpos = torch.tensor(
+ [[motor_angle, pendulum_angle]], dtype=torch.float32
+ )
+ qvel = torch.tensor(
+ [[motor_vel, pendulum_vel]], dtype=torch.float32
+ )
+ return qpos, qvel
+
+ # ------------------------------------------------------------------
+ # Physical reset helpers
+ # ------------------------------------------------------------------
+
+ def _drive_to_center(self) -> None:
+ """Drive the motor back toward encoder=0 using bang-bang control."""
+ rc = self._hw.reset
+ start = time.time()
+ while time.time() - start < rc.drive_timeout:
+ state = self._read_state_blocking()
+ enc = state["encoder_count"]
+ if abs(enc) < rc.deadband:
+ break
+ speed = rc.drive_speed if enc < 0 else -rc.drive_speed
+ self._send(f"M{speed}")
+ time.sleep(0.05)
+ self._send("M0")
+ time.sleep(0.2)
+
+ def _wait_for_pendulum_still(self) -> None:
+ """Block until the pendulum has settled (angle and velocity near zero)."""
+ rc = self._hw.reset
+ stable_since: float | None = None
+ start = time.monotonic()
+
+ while time.monotonic() - start < rc.settle_timeout:
+ state = self._read_state_blocking()
+ angle_ok = abs(state["pendulum_angle"]) < rc.settle_angle_deg
+ vel_ok = abs(state["pendulum_velocity"]) < rc.settle_vel_dps
+
+ if angle_ok and vel_ok:
+ if stable_since is None:
+ stable_since = time.monotonic()
+ elif time.monotonic() - stable_since >= rc.settle_duration:
+ logger.info(
+ "Pendulum settled after %.2f s",
+ time.monotonic() - start,
+ )
+ return
+ else:
+ stable_since = None
+ time.sleep(0.02)
+
+ logger.warning(
+ "Pendulum did not fully settle within %.1f s — proceeding anyway.",
+ rc.settle_timeout,
+ )
diff --git a/src/sysid/__init__.py b/src/sysid/__init__.py
new file mode 100644
index 0000000..9ad0001
--- /dev/null
+++ b/src/sysid/__init__.py
@@ -0,0 +1 @@
+"""System identification — tune simulation parameters to match real hardware."""
diff --git a/src/sysid/capture.py b/src/sysid/capture.py
new file mode 100644
index 0000000..9c54ea4
--- /dev/null
+++ b/src/sysid/capture.py
@@ -0,0 +1,381 @@
+"""Capture a real-robot trajectory under random excitation (PRBS-style).
+
+Connects to the ESP32 over serial, sends random PWM commands to excite
+the system, and records motor + pendulum angles and velocities at ~50 Hz.
+
+Saves a compressed numpy archive (.npz) that the optimizer can replay
+in simulation to fit physics parameters.
+
+Usage:
+ python -m src.sysid.capture \
+ --robot-path assets/rotary_cartpole \
+ --port /dev/cu.usbserial-0001 \
+ --duration 20
+"""
+
+from __future__ import annotations
+
+import argparse
+import math
+import os
+import random
+import threading
+import time
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import structlog
+import yaml
+
+log = structlog.get_logger()
+
+
+# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
+
+
+def _parse_state_line(line: str) -> dict[str, Any] | None:
+ """Parse an ``S,…`` state line from the ESP32."""
+ if not line.startswith("S,"):
+ return None
+ parts = line.split(",")
+ if len(parts) < 12:
+ return None
+ try:
+ return {
+ "timestamp_ms": int(parts[1]),
+ "encoder_count": int(parts[2]),
+ "rpm": float(parts[3]),
+ "motor_speed": int(parts[4]),
+ "at_limit": bool(int(parts[5])),
+ "pendulum_angle": float(parts[6]),
+ "pendulum_velocity": float(parts[7]),
+ "target_speed": int(parts[8]),
+ "braking": bool(int(parts[9])),
+ "enc_vel_cps": float(parts[10]),
+ "pendulum_ok": bool(int(parts[11])),
+ }
+ except (ValueError, IndexError):
+ return None
+
+
+# ── Background serial reader ─────────────────────────────────────────
+
+
+class _SerialReader:
+ """Minimal background reader for the ESP32 serial stream."""
+
+ def __init__(self, port: str, baud: int = 115200):
+ import serial as _serial
+
+ self._serial_mod = _serial
+ self.ser = _serial.Serial(port, baud, timeout=0.05)
+ time.sleep(2) # Wait for ESP32 boot.
+ self.ser.reset_input_buffer()
+
+ self._latest: dict[str, Any] = {}
+ self._lock = threading.Lock()
+ self._event = threading.Event()
+ self._running = True
+
+ self._thread = threading.Thread(target=self._reader_loop, daemon=True)
+ self._thread.start()
+
+ def _reader_loop(self) -> None:
+ while self._running:
+ try:
+ if self.ser.in_waiting:
+ line = (
+ self.ser.readline()
+ .decode("utf-8", errors="ignore")
+ .strip()
+ )
+ parsed = _parse_state_line(line)
+ if parsed is not None:
+ with self._lock:
+ self._latest = parsed
+ self._event.set()
+ else:
+ time.sleep(0.001)
+ except (OSError, self._serial_mod.SerialException):
+ log.critical("serial_lost")
+ break
+
+ def send(self, cmd: str) -> None:
+ try:
+ self.ser.write(f"{cmd}\n".encode())
+ except (OSError, self._serial_mod.SerialException):
+ log.critical("serial_send_failed", cmd=cmd)
+
+ def read(self) -> dict[str, Any]:
+ with self._lock:
+ return dict(self._latest)
+
+ def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
+ self._event.clear()
+ self._event.wait(timeout=timeout)
+ return self.read()
+
+ def close(self) -> None:
+ self._running = False
+ self.send("H")
+ self.send("M0")
+ time.sleep(0.1)
+ self._thread.join(timeout=1.0)
+ self.ser.close()
+
+
+# ── PRBS excitation signal ───────────────────────────────────────────
+
+
+class _PRBSExcitation:
+ """Random hold-value excitation with configurable amplitude and hold time.
+
+ At each call to ``__call__``, returns the current PWM value.
+    The value is held for a random duration (``hold_min_ms``–``hold_max_ms`` ms),
+    then a new random integer is drawn uniformly from ``[-amplitude, +amplitude]``.
+ """
+
+ def __init__(
+ self,
+ amplitude: int = 180,
+ hold_min_ms: int = 50,
+ hold_max_ms: int = 300,
+ ):
+ self.amplitude = amplitude
+ self.hold_min_ms = hold_min_ms
+ self.hold_max_ms = hold_max_ms
+ self._current: int = 0
+ self._switch_time: float = 0.0
+ self._new_value()
+
+ def _new_value(self) -> None:
+ self._current = random.randint(-self.amplitude, self.amplitude)
+ hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
+ self._switch_time = time.monotonic() + hold_ms / 1000.0
+
+ def __call__(self) -> int:
+ if time.monotonic() >= self._switch_time:
+ self._new_value()
+ return self._current
+
+
+# ── Main capture loop ────────────────────────────────────────────────
+
+
+def capture(
+ robot_path: str | Path,
+ port: str = "/dev/cu.usbserial-0001",
+ baud: int = 115200,
+ duration: float = 20.0,
+ amplitude: int = 180,
+ hold_min_ms: int = 50,
+ hold_max_ms: int = 300,
+ dt: float = 0.02,
+) -> Path:
+ """Run the capture procedure and return the path to the saved .npz file.
+
+ Parameters
+ ----------
+ robot_path : path to robot asset directory (contains hardware.yaml)
+ port : serial port for ESP32
+ baud : baud rate
+ duration : capture duration in seconds
+ amplitude : max PWM magnitude for excitation (0–255)
+ hold_min_ms / hold_max_ms : random hold time range (ms)
+ dt : target sampling period (seconds), default 50 Hz
+ """
+ robot_path = Path(robot_path).resolve()
+
+ # Load hardware config for encoder conversion + safety.
+ hw_yaml = robot_path / "hardware.yaml"
+ if not hw_yaml.exists():
+ raise FileNotFoundError(f"hardware.yaml not found in {robot_path}")
+ raw_hw = yaml.safe_load(hw_yaml.read_text())
+ ppr = raw_hw.get("encoder", {}).get("ppr", 11)
+ gear_ratio = raw_hw.get("encoder", {}).get("gear_ratio", 30.0)
+ counts_per_rev: float = ppr * gear_ratio * 4.0
+ max_motor_deg = raw_hw.get("safety", {}).get("max_motor_angle_deg", 90.0)
+ max_motor_rad = math.radians(max_motor_deg) if max_motor_deg > 0 else 0.0
+
+ # Connect.
+ reader = _SerialReader(port, baud)
+ excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)
+
+ # Prepare recording buffers.
+ max_samples = int(duration / dt) + 500 # headroom
+ rec_time = np.zeros(max_samples, dtype=np.float64)
+ rec_action = np.zeros(max_samples, dtype=np.float64)
+ rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
+ rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
+ rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
+ rec_pend_vel = np.zeros(max_samples, dtype=np.float64)
+
+ # Start streaming.
+ reader.send("G")
+ time.sleep(0.1)
+
+ log.info(
+ "capture_starting",
+ port=port,
+ duration=duration,
+ amplitude=amplitude,
+ hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
+ dt=dt,
+ )
+
+ idx = 0
+ t0 = time.monotonic()
+ try:
+ while True:
+ loop_start = time.monotonic()
+ elapsed = loop_start - t0
+ if elapsed >= duration:
+ break
+
+ # Get excitation PWM.
+ pwm = excitation()
+
+ # Safety: reverse/zero if near motor limit.
+ state = reader.read()
+ if state:
+ motor_angle_rad = (
+ state.get("encoder_count", 0) / counts_per_rev * 2.0 * math.pi
+ )
+ if max_motor_rad > 0:
+ margin = max_motor_rad * 0.85 # start braking at 85%
+ if motor_angle_rad > margin and pwm > 0:
+ pwm = -abs(pwm) # reverse
+ elif motor_angle_rad < -margin and pwm < 0:
+ pwm = abs(pwm) # reverse
+
+ # Send command.
+ reader.send(f"M{pwm}")
+
+ # Wait for fresh data.
+ time.sleep(max(0, dt - (time.monotonic() - loop_start) - 0.005))
+ state = reader.read_blocking(timeout=dt)
+
+ if state:
+ enc = state.get("encoder_count", 0)
+ motor_angle = enc / counts_per_rev * 2.0 * math.pi
+ motor_vel = (
+ state.get("enc_vel_cps", 0.0) / counts_per_rev * 2.0 * math.pi
+ )
+ pend_angle = math.radians(state.get("pendulum_angle", 0.0))
+ pend_vel = math.radians(state.get("pendulum_velocity", 0.0))
+
+ # Normalised action: PWM / 255 → [-1, 1]
+ action_norm = pwm / 255.0
+
+ if idx < max_samples:
+ rec_time[idx] = elapsed
+ rec_action[idx] = action_norm
+ rec_motor_angle[idx] = motor_angle
+ rec_motor_vel[idx] = motor_vel
+ rec_pend_angle[idx] = pend_angle
+ rec_pend_vel[idx] = pend_vel
+ idx += 1
+
+ # Progress.
+ if idx % 50 == 0:
+ log.info(
+ "capture_progress",
+ elapsed=f"{elapsed:.1f}/{duration:.0f}s",
+ samples=idx,
+ pwm=pwm,
+ )
+
+ # Pace to dt.
+ remaining = dt - (time.monotonic() - loop_start)
+ if remaining > 0:
+ time.sleep(remaining)
+
+ finally:
+ reader.send("M0")
+ reader.close()
+
+ # Trim to actual sample count.
+ rec_time = rec_time[:idx]
+ rec_action = rec_action[:idx]
+ rec_motor_angle = rec_motor_angle[:idx]
+ rec_motor_vel = rec_motor_vel[:idx]
+ rec_pend_angle = rec_pend_angle[:idx]
+ rec_pend_vel = rec_pend_vel[:idx]
+
+ # Save.
+ recordings_dir = robot_path / "recordings"
+ recordings_dir.mkdir(exist_ok=True)
+ stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ out_path = recordings_dir / f"capture_{stamp}.npz"
+ np.savez_compressed(
+ out_path,
+ time=rec_time,
+ action=rec_action,
+ motor_angle=rec_motor_angle,
+ motor_vel=rec_motor_vel,
+ pendulum_angle=rec_pend_angle,
+ pendulum_vel=rec_pend_vel,
+ )
+
+ log.info(
+ "capture_saved",
+ path=str(out_path),
+ samples=idx,
+ duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
+ )
+ return out_path
+
+
+# ── CLI entry point ──────────────────────────────────────────────────
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Capture a real-robot trajectory for system identification."
+ )
+ parser.add_argument(
+ "--robot-path",
+ type=str,
+ default="assets/rotary_cartpole",
+ help="Path to robot asset directory (contains hardware.yaml)",
+ )
+ parser.add_argument(
+ "--port",
+ type=str,
+ default="/dev/cu.usbserial-0001",
+ help="Serial port for ESP32",
+ )
+ parser.add_argument("--baud", type=int, default=115200)
+ parser.add_argument(
+ "--duration", type=float, default=20.0, help="Capture duration (s)"
+ )
+ parser.add_argument(
+ "--amplitude", type=int, default=180, help="Max PWM magnitude (0–255)"
+ )
+ parser.add_argument(
+ "--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
+ )
+ parser.add_argument(
+ "--hold-max-ms", type=int, default=300, help="Max hold time (ms)"
+ )
+ parser.add_argument(
+ "--dt", type=float, default=0.02, help="Sample period (s)"
+ )
+ args = parser.parse_args()
+
+ capture(
+ robot_path=args.robot_path,
+ port=args.port,
+ baud=args.baud,
+ duration=args.duration,
+ amplitude=args.amplitude,
+ hold_min_ms=args.hold_min_ms,
+ hold_max_ms=args.hold_max_ms,
+ dt=args.dt,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/sysid/export.py b/src/sysid/export.py
new file mode 100644
index 0000000..4dc02a9
--- /dev/null
+++ b/src/sysid/export.py
@@ -0,0 +1,186 @@
+"""Export tuned parameters to URDF and robot.yaml files.
+
+Reads the original files, injects the optimised parameter values,
+and writes ``rotary_cartpole_tuned.urdf`` + ``robot_tuned.yaml``
+alongside the originals in the robot asset directory.
+"""
+
+from __future__ import annotations
+
+import copy
+import json
+import shutil
+import xml.etree.ElementTree as ET
+from pathlib import Path
+
+import structlog
+import yaml
+
+log = structlog.get_logger()
+
+
def export_tuned_files(
    robot_path: str | Path,
    params: dict[str, float],
) -> tuple[Path, Path]:
    """Write tuned URDF and robot.yaml files.

    Parameters
    ----------
    robot_path : robot asset directory (contains robot.yaml + *.urdf)
    params : dict of parameter name → tuned value (from optimizer)

    Returns
    -------
    (tuned_urdf_path, tuned_robot_yaml_path)

    Notes
    -----
    Only the ``arm`` and ``pendulum`` links and the first actuator entry
    are updated; parameters absent from ``params`` keep their original
    values (the ``.get()`` lookups pass ``None``, which the XML helpers
    treat as "leave unchanged").
    """
    robot_path = Path(robot_path).resolve()

    # ── Load originals ───────────────────────────────────────────
    robot_yaml_path = robot_path / "robot.yaml"
    robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
    urdf_path = robot_path / robot_cfg["urdf"]

    # ── Tune URDF ────────────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            # Arm: only mass and centre of mass are tuned.
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            # Pendulum: mass, centre of mass, and inertia tensor entries.
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write tuned URDF.
    tuned_urdf_name = urdf_path.stem + "_tuned" + urdf_path.suffix
    tuned_urdf_path = robot_path / tuned_urdf_name

    # Preserve the XML declaration and original formatting as much as possible.
    # ET.indent requires Python 3.9+; encoding="unicode" selects text output.
    ET.indent(tree, space=" ")
    tree.write(str(tuned_urdf_path), xml_declaration=True, encoding="unicode")
    log.info("tuned_urdf_written", path=str(tuned_urdf_path))

    # ── Tune robot.yaml ──────────────────────────────────────────
    tuned_cfg = copy.deepcopy(robot_cfg)

    # Point to the tuned URDF.
    tuned_cfg["urdf"] = tuned_urdf_name

    # Update actuator parameters.
    if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
        act = tuned_cfg["actuators"][0]
        if "actuator_gear" in params:
            act["gear"] = round(params["actuator_gear"], 6)
        if "actuator_filter_tau" in params:
            act["filter_tau"] = round(params["actuator_filter_tau"], 6)
        if "motor_damping" in params:
            act["damping"] = round(params["motor_damping"], 6)

    # Update joint overrides.
    # Create missing mapping levels so partial configs still export.
    if "joints" not in tuned_cfg:
        tuned_cfg["joints"] = {}

    if "motor_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["motor_joint"] = {}
    mj = tuned_cfg["joints"]["motor_joint"]
    if "motor_armature" in params:
        mj["armature"] = round(params["motor_armature"], 6)
    if "motor_frictionloss" in params:
        mj["frictionloss"] = round(params["motor_frictionloss"], 6)

    if "pendulum_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["pendulum_joint"] = {}
    pj = tuned_cfg["joints"]["pendulum_joint"]
    if "pendulum_damping" in params:
        pj["damping"] = round(params["pendulum_damping"], 6)

    # Write tuned robot.yaml.
    tuned_yaml_path = robot_path / "robot_tuned.yaml"

    # Add a header comment.
    header = (
        "# Tuned robot config — generated by src.sysid.optimize\n"
        "# Original: robot.yaml\n"
        "# Run `python -m src.sysid.visualize` to compare real vs sim.\n\n"
    )
    tuned_yaml_path.write_text(
        header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
    )
    log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))

    return tuned_urdf_path, tuned_yaml_path
+
+
+# ── XML helpers (shared with rollout.py) ─────────────────────────────
+
+
+def _set_mass(inertial: ET.Element, mass: float | None) -> None:
+ if mass is None:
+ return
+ mass_el = inertial.find("mass")
+ if mass_el is not None:
+ mass_el.set("value", str(mass))
+
+
+def _set_com(
+ inertial: ET.Element,
+ x: float | None,
+ y: float | None,
+ z: float | None,
+) -> None:
+ origin = inertial.find("origin")
+ if origin is None:
+ return
+ xyz = origin.get("xyz", "0 0 0").split()
+ if x is not None:
+ xyz[0] = str(x)
+ if y is not None:
+ xyz[1] = str(y)
+ if z is not None:
+ xyz[2] = str(z)
+ origin.set("xyz", " ".join(xyz))
+
+
+def _set_inertia(
+ inertial: ET.Element,
+ ixx: float | None = None,
+ iyy: float | None = None,
+ izz: float | None = None,
+ ixy: float | None = None,
+ iyz: float | None = None,
+ ixz: float | None = None,
+) -> None:
+ ine = inertial.find("inertia")
+ if ine is None:
+ return
+ for attr, val in [
+ ("ixx", ixx), ("iyy", iyy), ("izz", izz),
+ ("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
+ ]:
+ if val is not None:
+ ine.set(attr, str(val))
diff --git a/src/sysid/optimize.py b/src/sysid/optimize.py
new file mode 100644
index 0000000..1ffc03f
--- /dev/null
+++ b/src/sysid/optimize.py
@@ -0,0 +1,376 @@
+"""CMA-ES optimiser — fit simulation parameters to a real-robot recording.
+
+Minimises the trajectory-matching cost between a MuJoCo rollout and a
+recorded real-robot sequence. Uses the ``cmaes`` package (pure-Python
+CMA-ES with native box-constraint support).
+
+Usage:
+ python -m src.sysid.optimize \
+ --robot-path assets/rotary_cartpole \
+ --recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
+
+ # Shorter run for testing:
+ python -m src.sysid.optimize \
+ --robot-path assets/rotary_cartpole \
+ --recording path/to/capture.npz \
+ --max-generations 10 --population-size 8
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import time
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import structlog
+
+from src.sysid.rollout import (
+ ROTARY_CARTPOLE_PARAMS,
+ ParamSpec,
+ bounds_arrays,
+ defaults_vector,
+ params_to_dict,
+ rollout,
+ windowed_rollout,
+)
+
+log = structlog.get_logger()
+
+
+# ── Cost function ────────────────────────────────────────────────────
+
+
+def _angle_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ """Shortest signed angle difference, handling wrapping."""
+ return np.arctan2(np.sin(a - b), np.cos(a - b))
+
+
+def _check_inertia_valid(params: dict[str, float]) -> bool:
+ """Quick reject: pendulum inertia tensor must be positive-definite."""
+ ixx = params.get("pendulum_ixx", 6.16e-06)
+ iyy = params.get("pendulum_iyy", 6.16e-06)
+ izz = params.get("pendulum_izz", 1.23e-05)
+ ixy = params.get("pendulum_ixy", 6.10e-06)
+ det_xy = ixx * iyy - ixy * ixy
+ return det_xy > 0 and ixx > 0 and iyy > 0 and izz > 0
+
+
def _compute_trajectory_cost(
    sim: dict[str, np.ndarray],
    recording: dict[str, np.ndarray],
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
) -> float:
    """Weighted mean-squared error between sim and real trajectories.

    Angle channels are compared through the shortest wrapped difference
    (:func:`_angle_diff`); velocity channels use the plain difference.
    Position error carries ``pos_weight``, velocity error ``vel_weight``.
    """
    angle_cost = sum(
        np.mean(_angle_diff(sim[key], recording[key]) ** 2)
        for key in ("motor_angle", "pendulum_angle")
    )
    vel_cost = sum(
        np.mean((sim[key] - recording[key]) ** 2)
        for key in ("motor_vel", "pendulum_vel")
    )
    return float(pos_weight * angle_cost + vel_weight * vel_cost)
+
+
def cost_function(
    params_vec: np.ndarray,
    recording: dict[str, np.ndarray],
    robot_path: Path,
    specs: list[ParamSpec],
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
) -> float:
    """Compute trajectory-matching cost for a candidate parameter vector.

    Uses **multiple-shooting** (windowed rollout): the recording is split
    into short windows (default 0.5 s). Each window is initialised from
    the real qpos/qvel, so early errors don’t compound across the full
    trajectory. This gives a much smoother cost landscape for CMA-ES.

    Set ``window_duration=0`` to fall back to the original open-loop
    single-shot rollout (not recommended).

    Returns the weighted MSE from :func:`_compute_trajectory_cost`, or
    the sentinel ``1e6`` when the candidate is infeasible or the rollout
    raises.
    """
    # Map the flat optimizer vector onto named parameters (order = specs).
    params = params_to_dict(params_vec, specs)

    # Cheap pre-screen: an invalid inertia tensor would make MuJoCo
    # reject the model, so return the sentinel without simulating.
    if not _check_inertia_valid(params):
        return 1e6

    try:
        if window_duration > 0:
            sim = windowed_rollout(
                robot_path=robot_path,
                params=params,
                recording=recording,
                window_duration=window_duration,
                sim_dt=sim_dt,
                substeps=substeps,
            )
        else:
            sim = rollout(
                robot_path=robot_path,
                params=params,
                actions=recording["action"],
                timesteps=recording["time"],
                sim_dt=sim_dt,
                substeps=substeps,
            )
    except Exception as exc:
        # Broad catch is deliberate: any model-build or integration
        # failure marks this candidate as bad instead of aborting the
        # whole optimisation run.
        log.warning("rollout_failed", error=str(exc))
        return 1e6

    return _compute_trajectory_cost(sim, recording, pos_weight, vel_weight)
+
+
+# ── CMA-ES optimisation loop ────────────────────────────────────────
+
+
def optimize(
    robot_path: str | Path,
    recording_path: str | Path,
    specs: list[ParamSpec] | None = None,
    sigma0: float = 0.3,
    population_size: int = 20,
    max_generations: int = 1000,
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
    seed: int = 42,
) -> dict:
    """Run CMA-ES optimisation and return results.

    Candidates are sampled in a [0, 1]-normalised space (better
    conditioned for CMA-ES) and mapped back to natural units for
    evaluation.

    Returns a dict with:
        best_params: dict[str, float]
        best_cost: float
        history: list of (generation, best_cost) tuples
        recording: str (path used)
        param_names: list of parameter names, one per spec
        defaults: dict of parameter name → default value
        timestamp: ISO-8601 time the run finished
    """
    # Local import so the cmaes dependency is only needed when optimising.
    from cmaes import CMA

    robot_path = Path(robot_path).resolve()
    recording_path = Path(recording_path).resolve()

    if specs is None:
        specs = ROTARY_CARTPOLE_PARAMS

    # Load recording.
    recording = dict(np.load(recording_path))
    n_samples = len(recording["time"])
    duration = recording["time"][-1] - recording["time"][0]
    # n_windows here is informational (for the log line); the actual
    # windowing happens inside windowed_rollout.
    n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
    log.info(
        "recording_loaded",
        path=str(recording_path),
        samples=n_samples,
        duration=f"{duration:.1f}s",
        window_duration=f"{window_duration}s",
        n_windows=n_windows,
    )

    # Initial point (defaults) — normalised to [0, 1] for CMA-ES.
    lo, hi = bounds_arrays(specs)
    x0 = defaults_vector(specs)

    # Normalise to [0, 1] for the optimizer (better conditioned).
    span = hi - lo
    span[span == 0] = 1.0  # avoid division by zero

    def to_normed(x: np.ndarray) -> np.ndarray:
        # Natural units → [0, 1] box.
        return (x - lo) / span

    def from_normed(x_n: np.ndarray) -> np.ndarray:
        # [0, 1] box → natural units.
        return x_n * span + lo

    x0_normed = to_normed(x0)
    bounds_normed = np.column_stack(
        [np.zeros(len(specs)), np.ones(len(specs))]
    )

    optimizer = CMA(
        mean=x0_normed,
        sigma=sigma0,
        bounds=bounds_normed,
        population_size=population_size,
        seed=seed,
    )

    # Best-so-far is tracked in *natural* units so it can be returned
    # directly without denormalising at the end.
    best_cost = float("inf")
    best_params_vec = x0.copy()
    history: list[tuple[int, float]] = []

    log.info(
        "cmaes_starting",
        n_params=len(specs),
        population=population_size,
        max_gens=max_generations,
        sigma0=sigma0,
    )

    t0 = time.monotonic()

    for gen in range(max_generations):
        solutions = []
        for _ in range(optimizer.population_size):
            x_normed = optimizer.ask()
            x_natural = from_normed(x_normed)

            # Clip to bounds (CMA-ES can slightly exceed with sampling noise).
            x_natural = np.clip(x_natural, lo, hi)

            c = cost_function(
                x_natural,
                recording,
                robot_path,
                specs,
                sim_dt=sim_dt,
                substeps=substeps,
                pos_weight=pos_weight,
                vel_weight=vel_weight,
                window_duration=window_duration,
            )
            # tell() expects the *normalised* candidate that was asked for.
            solutions.append((x_normed, c))

            if c < best_cost:
                best_cost = c
                best_params_vec = x_natural.copy()

        optimizer.tell(solutions)
        history.append((gen, best_cost))

        elapsed = time.monotonic() - t0
        # Log every 5th generation (and the last) to keep output readable.
        if gen % 5 == 0 or gen == max_generations - 1:
            log.info(
                "cmaes_generation",
                gen=gen,
                best_cost=f"{best_cost:.6f}",
                elapsed=f"{elapsed:.1f}s",
                gen_best=f"{min(c for _, c in solutions):.6f}",
            )

    total_time = time.monotonic() - t0
    best_params = params_to_dict(best_params_vec, specs)

    log.info(
        "cmaes_finished",
        best_cost=f"{best_cost:.6f}",
        total_time=f"{total_time:.1f}s",
        evaluations=max_generations * population_size,
    )

    # Log parameter comparison.
    defaults = params_to_dict(defaults_vector(specs), specs)
    for name in best_params:
        d = defaults[name]
        b = best_params[name]
        # Guard against division by (near-)zero defaults.
        change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
        log.info(
            "param_result",
            name=name,
            default=f"{d:.6g}",
            tuned=f"{b:.6g}",
            change=f"{change_pct:+.1f}%",
        )

    return {
        "best_params": best_params,
        "best_cost": best_cost,
        "history": history,
        "recording": str(recording_path),
        "param_names": [s.name for s in specs],
        "defaults": {s.name: s.default for s in specs},
        "timestamp": datetime.now().isoformat(),
    }
+
+
+# ── CLI entry point ──────────────────────────────────────────────────
+
+
def main() -> None:
    """CLI entry point — optimise, save results JSON, export tuned files.

    Sequence: parse args → run :func:`optimize` → write
    ``sysid_result.json`` into the robot asset directory → unless
    ``--no-export``, write the tuned URDF/robot.yaml via
    ``src.sysid.export.export_tuned_files``.
    """
    parser = argparse.ArgumentParser(
        description="Fit simulation parameters to a real-robot recording (CMA-ES)."
    )
    parser.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
        help="Path to robot asset directory",
    )
    parser.add_argument(
        "--recording",
        type=str,
        required=True,
        help="Path to .npz recording file",
    )
    parser.add_argument("--sigma0", type=float, default=0.3)
    parser.add_argument("--population-size", type=int, default=20)
    parser.add_argument("--max-generations", type=int, default=200)
    parser.add_argument("--sim-dt", type=float, default=0.002)
    parser.add_argument("--substeps", type=int, default=10)
    parser.add_argument("--pos-weight", type=float, default=1.0)
    parser.add_argument("--vel-weight", type=float, default=0.1)
    parser.add_argument(
        "--window-duration",
        type=float,
        default=0.5,
        help="Shooting window length in seconds (0 = open-loop, default 0.5)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no-export",
        action="store_true",
        help="Skip exporting tuned files (results JSON only)",
    )
    args = parser.parse_args()

    result = optimize(
        robot_path=args.robot_path,
        recording_path=args.recording,
        sigma0=args.sigma0,
        population_size=args.population_size,
        max_generations=args.max_generations,
        sim_dt=args.sim_dt,
        substeps=args.substeps,
        pos_weight=args.pos_weight,
        vel_weight=args.vel_weight,
        window_duration=args.window_duration,
        seed=args.seed,
    )

    # Save results JSON.
    robot_path = Path(args.robot_path).resolve()
    result_path = robot_path / "sysid_result.json"
    # Convert numpy types for JSON serialisation.
    # The full per-generation history is dropped in favour of a small
    # summary; default=str below handles any remaining non-JSON values.
    result_json = {
        k: v for k, v in result.items() if k != "history"
    }
    result_json["history_summary"] = {
        "first_cost": result["history"][0][1] if result["history"] else None,
        "final_cost": result["history"][-1][1] if result["history"] else None,
        "generations": len(result["history"]),
    }
    result_path.write_text(json.dumps(result_json, indent=2, default=str))
    log.info("results_saved", path=str(result_path))

    # Export tuned files unless --no-export.
    if not args.no_export:
        # Local import avoids a hard dependency when only optimising.
        from src.sysid.export import export_tuned_files

        export_tuned_files(
            robot_path=args.robot_path,
            params=result["best_params"],
        )


if __name__ == "__main__":
    main()
diff --git a/src/sysid/rollout.py b/src/sysid/rollout.py
new file mode 100644
index 0000000..cbdcf53
--- /dev/null
+++ b/src/sysid/rollout.py
@@ -0,0 +1,477 @@
+"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
+
+Given a parameter vector and a recorded action sequence, builds a MuJoCo
+model with overridden physics parameters, replays the actions, and returns
+the simulated trajectory for comparison with the real recording.
+
+This module is the inner loop of the CMA-ES optimizer: it is called once
+per candidate parameter vector per generation.
+"""
+
+from __future__ import annotations
+
+import copy
+import dataclasses
+import math
+import tempfile
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any
+
+import mujoco
+import numpy as np
+import yaml
+
+
+# ── Tunable parameter specification ──────────────────────────────────
+
+
@dataclasses.dataclass
class ParamSpec:
    """Specification for a single tunable parameter.

    Consumed by ``params_to_dict`` / ``defaults_vector`` /
    ``bounds_arrays``, which map flat optimizer vectors onto these specs.
    """

    name: str  # identifier, used as the key in parameter dicts
    default: float  # nominal value in natural units (optimizer start)
    lower: float  # hard lower box bound (natural units)
    upper: float  # hard upper box bound (natural units)
    # NOTE(review): none of the code visible in this module reads
    # log_scale — the optimizer normalises linearly; confirm intent.
    log_scale: bool = False  # optimise in log-space (masses, inertias)


# Default parameter specs for the rotary cartpole.
# Order matters: the optimizer maps a flat vector to these specs.
# Units presumably follow URDF convention (kg, m, kg·m²) — confirm
# against the asset files.
ROTARY_CARTPOLE_PARAMS: list[ParamSpec] = [
    # ── Arm link (URDF) ──────────────────────────────────────────
    ParamSpec("arm_mass", 0.010, 0.003, 0.05, log_scale=True),
    ParamSpec("arm_com_x", 0.00005, -0.02, 0.02),
    ParamSpec("arm_com_y", 0.0065, -0.01, 0.02),
    ParamSpec("arm_com_z", 0.00563, -0.01, 0.02),
    # ── Pendulum link (URDF) ─────────────────────────────────────
    ParamSpec("pendulum_mass", 0.015, 0.005, 0.05, log_scale=True),
    ParamSpec("pendulum_com_x", 0.1583, 0.05, 0.25),
    ParamSpec("pendulum_com_y", -0.0983, -0.20, 0.0),
    ParamSpec("pendulum_com_z", 0.0, -0.05, 0.05),
    ParamSpec("pendulum_ixx", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_iyy", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_izz", 1.23e-05, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_ixy", 6.10e-06, -1e-04, 1e-04),
    # ── Actuator / joint dynamics (robot.yaml) ───────────────────
    ParamSpec("actuator_gear", 0.064, 0.01, 0.2, log_scale=True),
    ParamSpec("actuator_filter_tau", 0.03, 0.005, 0.15),
    ParamSpec("motor_damping", 0.003, 1e-4, 0.05, log_scale=True),
    ParamSpec("pendulum_damping", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_armature", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_frictionloss", 0.03, 0.001, 0.1, log_scale=True),
]
+
+
def params_to_dict(
    values: np.ndarray, specs: list[ParamSpec] | None = None
) -> dict[str, float]:
    """Convert a flat parameter vector to a named dict.

    Entry ``i`` of ``values`` is paired with ``specs[i].name``; ``specs``
    defaults to ``ROTARY_CARTPOLE_PARAMS``.
    """
    if specs is None:
        specs = ROTARY_CARTPOLE_PARAMS
    named: dict[str, float] = {}
    for position, spec in enumerate(specs):
        named[spec.name] = float(values[position])
    return named
+
+
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
    """Return the default parameter vector (in natural space).

    ``specs`` defaults to ``ROTARY_CARTPOLE_PARAMS``; ordering follows
    the spec list.
    """
    chosen = ROTARY_CARTPOLE_PARAMS if specs is None else specs
    defaults = [spec.default for spec in chosen]
    return np.asarray(defaults, dtype=np.float64)
+
+
def bounds_arrays(
    specs: list[ParamSpec] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Return (lower, upper) bound arrays in spec order.

    ``specs`` defaults to ``ROTARY_CARTPOLE_PARAMS``.
    """
    chosen = ROTARY_CARTPOLE_PARAMS if specs is None else specs
    lower = np.array([spec.lower for spec in chosen], dtype=np.float64)
    upper = np.array([spec.upper for spec in chosen], dtype=np.float64)
    return lower, upper
+
+
+# ── MuJoCo model building with parameter overrides ──────────────────
+
+
def _build_model(
    robot_path: Path,
    params: dict[str, float],
) -> mujoco.MjModel:
    """Build a MuJoCo model from URDF + robot.yaml with parameter overrides.

    Follows the same two-step approach as ``MuJoCoRunner._load_model()``:
    1. Parse URDF, inject meshdir, load into MuJoCo
    2. Export MJCF, inject actuators + joint overrides + param overrides, reload

    NOTE(review): the temp files use fixed names inside ``robot_path``
    (``_tmp_sysid_load.urdf`` / ``_tmp_sysid_inject.xml``), so two
    concurrent builds against the same asset directory would collide —
    confirm single-process use before parallelising the optimizer.
    """
    robot_path = Path(robot_path).resolve()
    robot_yaml = yaml.safe_load((robot_path / "robot.yaml").read_text())
    urdf_path = robot_path / robot_yaml["urdf"]

    # ── Step 1: Load URDF ────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    # Inject meshdir compiler directive.
    # Derived from the first mesh filename that has a directory component.
    meshdir = None
    for mesh_el in root.iter("mesh"):
        fn = mesh_el.get("filename", "")
        parent = str(Path(fn).parent)
        if parent and parent != ".":
            meshdir = parent
            break
    if meshdir:
        mj_ext = ET.SubElement(root, "mujoco")
        ET.SubElement(
            mj_ext, "compiler", attrib={"meshdir": meshdir, "balanceinertia": "true"}
        )

    # Override URDF inertial parameters BEFORE MuJoCo loading.
    # Missing keys in ``params`` leave the original URDF values in place.
    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write temp URDF and load.
    tmp_urdf = robot_path / "_tmp_sysid_load.urdf"
    tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
    try:
        model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
    finally:
        # Always clean up, even when MuJoCo rejects the model.
        tmp_urdf.unlink(missing_ok=True)

    # ── Step 2: Export MJCF, inject actuators + overrides ────────
    tmp_mjcf = robot_path / "_tmp_sysid_inject.xml"
    try:
        mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
        mjcf_root = ET.fromstring(tmp_mjcf.read_text())

        # Actuator.
        # Fall back to robot.yaml values (then hard-coded defaults) for
        # anything not present in ``params``.
        gear = params.get("actuator_gear", robot_yaml["actuators"][0].get("gear", 0.064))
        filter_tau = params.get(
            "actuator_filter_tau",
            robot_yaml["actuators"][0].get("filter_tau", 0.03),
        )
        act_cfg = robot_yaml["actuators"][0]
        ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])

        act_elem = ET.SubElement(mjcf_root, "actuator")
        attribs: dict[str, str] = {
            "name": f"{act_cfg['joint']}_motor",
            "joint": act_cfg["joint"],
            "gear": str(gear),
            "ctrlrange": f"{ctrl_lo} {ctrl_hi}",
        }
        if filter_tau > 0:
            # First-order filter on the control signal (actuator lag).
            attribs["dyntype"] = "filter"
            attribs["dynprm"] = str(filter_tau)
            attribs["gaintype"] = "fixed"
            attribs["biastype"] = "none"
            ET.SubElement(act_elem, "general", attrib=attribs)
        else:
            ET.SubElement(act_elem, "motor", attrib=attribs)

        # Joint overrides.
        motor_damping = params.get("motor_damping", 0.003)
        pend_damping = params.get("pendulum_damping", 0.0001)
        motor_armature = params.get("motor_armature", 0.0001)
        motor_frictionloss = params.get("motor_frictionloss", 0.03)

        for body in mjcf_root.iter("body"):
            for jnt in body.findall("joint"):
                name = jnt.get("name")
                if name == "motor_joint":
                    jnt.set("damping", str(motor_damping))
                    jnt.set("armature", str(motor_armature))
                    jnt.set("frictionloss", str(motor_frictionloss))
                elif name == "pendulum_joint":
                    jnt.set("damping", str(pend_damping))

        # Disable self-collision.
        for geom in mjcf_root.iter("geom"):
            geom.set("contype", "0")
            geom.set("conaffinity", "0")

        modified_xml = ET.tostring(mjcf_root, encoding="unicode")
        tmp_mjcf.write_text(modified_xml)
        model = mujoco.MjModel.from_xml_path(str(tmp_mjcf))
    finally:
        tmp_mjcf.unlink(missing_ok=True)

    return model
+
+
+def _set_mass(inertial: ET.Element, mass: float | None) -> None:
+ if mass is None:
+ return
+ mass_el = inertial.find("mass")
+ if mass_el is not None:
+ mass_el.set("value", str(mass))
+
+
+def _set_com(
+ inertial: ET.Element,
+ x: float | None,
+ y: float | None,
+ z: float | None,
+) -> None:
+ origin = inertial.find("origin")
+ if origin is None:
+ return
+ xyz = origin.get("xyz", "0 0 0").split()
+ if x is not None:
+ xyz[0] = str(x)
+ if y is not None:
+ xyz[1] = str(y)
+ if z is not None:
+ xyz[2] = str(z)
+ origin.set("xyz", " ".join(xyz))
+
+
+def _set_inertia(
+ inertial: ET.Element,
+ ixx: float | None = None,
+ iyy: float | None = None,
+ izz: float | None = None,
+ ixy: float | None = None,
+ iyz: float | None = None,
+ ixz: float | None = None,
+) -> None:
+ ine = inertial.find("inertia")
+ if ine is None:
+ return
+ for attr, val in [
+ ("ixx", ixx), ("iyy", iyy), ("izz", izz),
+ ("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
+ ]:
+ if val is not None:
+ ine.set(attr, str(val))
+
+
+# ── Simulation rollout ───────────────────────────────────────────────
+
+
def rollout(
    robot_path: str | Path,
    params: dict[str, float],
    actions: np.ndarray,
    timesteps: np.ndarray,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray]:
    """Replay recorded actions in MuJoCo with overridden parameters.

    Each action is held constant for ``substeps`` physics steps, i.e. a
    fixed control period of ``sim_dt * substeps`` seconds. The recorded
    ``timesteps`` are NOT used to resample the actions, so callers
    should choose ``sim_dt``/``substeps`` to match the recording's
    sample period.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    actions : (N,) normalised actions [-1, 1] from the recording
    timesteps : (N,) wall-clock times (seconds) from the recording;
        accepted for interface parity, currently unused (see above)
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step

    Returns
    -------
    dict with keys: motor_angle, motor_vel, pendulum_angle, pendulum_vel
        Each is an (N,) numpy array of simulated values.
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    # Start from pendulum hanging down (qpos=0 is down per URDF convention).
    mujoco.mj_resetData(model, data)

    n = len(actions)

    # Pre-allocate output.
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info for software limit switch.
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    for i in range(n):
        data.ctrl[0] = actions[i]

        for _ in range(substeps):
            # Software limit switch (mirrors MuJoCoRunner): cut the drive
            # when it would push further into a joint limit.
            if jnt_limited and nu > 0:
                pos = data.qpos[jnt_id]
                if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                    data.ctrl[0] = 0.0
                elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                    data.ctrl[0] = 0.0
            mujoco.mj_step(model, data)

        # Sample state once per control step (after all substeps).
        # NOTE(review): assumes qpos[0]/qvel[0] is the motor joint and
        # qpos[1]/qvel[1] the pendulum joint — confirm against the URDF.
        sim_motor_angle[i] = data.qpos[0]
        sim_motor_vel[i] = data.qvel[0]
        sim_pend_angle[i] = data.qpos[1]
        sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
    }
+
+
def windowed_rollout(
    robot_path: str | Path,
    params: dict[str, float],
    recording: dict[str, np.ndarray],
    window_duration: float = 0.5,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray | float]:
    """Multiple-shooting rollout — split recording into short windows.

    For each window:
    1. Initialize MuJoCo state from the real qpos/qvel at the window start.
    2. Replay the recorded actions within the window.
    3. Record the simulated output.

    This prevents error accumulation across the full trajectory, giving
    a much smoother cost landscape for the optimizer.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    recording : dict with keys time, action, motor_angle, motor_vel,
        pendulum_angle, pendulum_vel (all 1D arrays of length N)
    window_duration : length of each shooting window in seconds
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step

    Returns
    -------
    dict with:
        motor_angle, motor_vel, pendulum_angle, pendulum_vel — (N,) arrays
        (stitched from per-window simulations)
        n_windows — number of windows used (an int, despite the
        ndarray-oriented value annotation)
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    times = recording["time"]
    actions = recording["action"]
    real_motor = recording["motor_angle"]
    real_motor_vel = recording["motor_vel"]
    real_pend = recording["pendulum_angle"]
    real_pend_vel = recording["pendulum_vel"]
    n = len(actions)

    # Pre-allocate output (stitched from all windows).
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info.
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    # Compute window boundaries from recording timestamps.
    # If window_duration is shorter than the sample period, consecutive
    # starts can coincide; the duplicate window's index range is then
    # empty and it is skipped implicitly by the inner loop.
    t0 = times[0]
    t_end = times[-1]
    window_starts: list[int] = []  # indices into the recording
    current_t = t0
    while current_t < t_end:
        # Find the index closest to current_t.
        idx = int(np.searchsorted(times, current_t))
        idx = min(idx, n - 1)
        window_starts.append(idx)
        current_t += window_duration

    n_windows = len(window_starts)

    for w, w_start in enumerate(window_starts):
        # Window end: next window start, or end of recording.
        w_end = window_starts[w + 1] if w + 1 < n_windows else n

        # Initialize MuJoCo state from real data at window start.
        # NOTE(review): assumes qpos[0]/qvel[0] is the motor joint and
        # qpos[1]/qvel[1] the pendulum joint — confirm against the URDF.
        mujoco.mj_resetData(model, data)
        data.qpos[0] = real_motor[w_start]
        data.qpos[1] = real_pend[w_start]
        data.qvel[0] = real_motor_vel[w_start]
        data.qvel[1] = real_pend_vel[w_start]
        data.ctrl[:] = 0.0
        # Forward kinematics to make state consistent.
        mujoco.mj_forward(model, data)

        for i in range(w_start, w_end):
            data.ctrl[0] = actions[i]

            for _ in range(substeps):
                # Software limit switch: cut the drive when it would push
                # further into a joint limit (mirrors rollout()).
                if jnt_limited and nu > 0:
                    pos = data.qpos[jnt_id]
                    if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                        data.ctrl[0] = 0.0
                    elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                        data.ctrl[0] = 0.0
                mujoco.mj_step(model, data)

            # Sample state once per control step.
            sim_motor_angle[i] = data.qpos[0]
            sim_motor_vel[i] = data.qvel[0]
            sim_pend_angle[i] = data.qpos[1]
            sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
        "n_windows": n_windows,
    }
diff --git a/src/sysid/visualize.py b/src/sysid/visualize.py
new file mode 100644
index 0000000..987791c
--- /dev/null
+++ b/src/sysid/visualize.py
@@ -0,0 +1,287 @@
+"""Visualise system identification results — real vs simulated trajectories.
+
+Loads a recording and runs simulation with both the original and tuned
+parameters, then plots a 4-panel comparison (motor angle, motor vel,
+pendulum angle, pendulum vel) over time.
+
+Usage:
+ python -m src.sysid.visualize \
+ --robot-path assets/rotary_cartpole \
+ --recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
+
+ # Also compare with tuned parameters:
+ python -m src.sysid.visualize \
+ --robot-path assets/rotary_cartpole \
+        --recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz \
+ --result assets/rotary_cartpole/sysid_result.json
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+from pathlib import Path
+
+import numpy as np
+import structlog
+
+log = structlog.get_logger()
+
+
+def visualize(
+ robot_path: str | Path,
+ recording_path: str | Path,
+ result_path: str | Path | None = None,
+ sim_dt: float = 0.002,
+ substeps: int = 10,
+ window_duration: float = 0.5,
+ save_path: str | Path | None = None,
+ show: bool = True,
+) -> None:
+ """Generate comparison plot.
+
+ Parameters
+ ----------
+ robot_path : robot asset directory
+ recording_path : .npz file from capture
+ result_path : sysid_result.json with best_params (optional)
+ sim_dt / substeps : physics settings for rollout
+ window_duration : shooting window length (s); 0 = open-loop
+ save_path : if provided, save figure to this path (PNG, PDF, …)
+ show : if True, display interactive matplotlib window
+ """
+ import matplotlib.pyplot as plt
+
+ from src.sysid.rollout import (
+ ROTARY_CARTPOLE_PARAMS,
+ defaults_vector,
+ params_to_dict,
+ rollout,
+ windowed_rollout,
+ )
+
+ robot_path = Path(robot_path).resolve()
+ recording = dict(np.load(recording_path))
+
+ t = recording["time"]
+ actions = recording["action"]
+
+ # ── Simulate with default parameters ─────────────────────────
+ default_params = params_to_dict(
+ defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
+ )
+ log.info("simulating_default_params", windowed=window_duration > 0)
+ if window_duration > 0:
+ sim_default = windowed_rollout(
+ robot_path=robot_path,
+ params=default_params,
+ recording=recording,
+ window_duration=window_duration,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+ else:
+ sim_default = rollout(
+ robot_path=robot_path,
+ params=default_params,
+ actions=actions,
+ timesteps=t,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+
+ # ── Simulate with tuned parameters (if available) ────────────
+ sim_tuned = None
+ tuned_cost = None
+ if result_path is not None:
+ result_path = Path(result_path)
+ if result_path.exists():
+ result = json.loads(result_path.read_text())
+ tuned_params = result.get("best_params", {})
+ tuned_cost = result.get("best_cost")
+ log.info("simulating_tuned_params", cost=tuned_cost)
+ if window_duration > 0:
+ sim_tuned = windowed_rollout(
+ robot_path=robot_path,
+ params=tuned_params,
+ recording=recording,
+ window_duration=window_duration,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+ else:
+ sim_tuned = rollout(
+ robot_path=robot_path,
+ params=tuned_params,
+ actions=actions,
+ timesteps=t,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+ else:
+ log.warning("result_file_not_found", path=str(result_path))
+ else:
+ # Auto-detect sysid_result.json in robot_path.
+ auto_result = robot_path / "sysid_result.json"
+ if auto_result.exists():
+ result = json.loads(auto_result.read_text())
+ tuned_params = result.get("best_params", {})
+ tuned_cost = result.get("best_cost")
+ log.info("auto_detected_tuned_params", cost=tuned_cost)
+ if window_duration > 0:
+ sim_tuned = windowed_rollout(
+ robot_path=robot_path,
+ params=tuned_params,
+ recording=recording,
+ window_duration=window_duration,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+ else:
+ sim_tuned = rollout(
+ robot_path=robot_path,
+ params=tuned_params,
+ actions=actions,
+ timesteps=t,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ )
+
+ # ── Plot ─────────────────────────────────────────────────────
+ fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
+
+ channels = [
+ ("motor_angle", "Motor Angle (rad)", True),
+ ("motor_vel", "Motor Velocity (rad/s)", False),
+ ("pendulum_angle", "Pendulum Angle (rad)", True),
+ ("pendulum_vel", "Pendulum Velocity (rad/s)", False),
+ ]
+
+ for ax, (key, ylabel, is_angle) in zip(axes[:4], channels):
+ real = recording[key]
+
+ ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
+ ax.plot(
+ t,
+ sim_default[key],
+ "--",
+ color="#d62728",
+ linewidth=1.0,
+ alpha=0.7,
+ label="Sim (original)",
+ )
+ if sim_tuned is not None:
+ ax.plot(
+ t,
+ sim_tuned[key],
+ "--",
+ color="#2ca02c",
+ linewidth=1.0,
+ alpha=0.7,
+ label="Sim (tuned)",
+ )
+
+ ax.set_ylabel(ylabel)
+ ax.legend(loc="upper right", fontsize=8)
+ ax.grid(True, alpha=0.3)
+
+ # Action plot (bottom panel).
+ axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
+ axes[4].set_ylabel("Action (norm)")
+ axes[4].set_xlabel("Time (s)")
+ axes[4].grid(True, alpha=0.3)
+ axes[4].set_ylim(-1.1, 1.1)
+
+ # Title with cost info.
+ title = "System Identification — Real vs Simulated Trajectories"
+ if tuned_cost is not None:
+ # Compute original cost for comparison.
+ from src.sysid.optimize import cost_function
+
+ orig_cost = cost_function(
+ defaults_vector(ROTARY_CARTPOLE_PARAMS),
+ recording,
+ robot_path,
+ ROTARY_CARTPOLE_PARAMS,
+ sim_dt=sim_dt,
+ substeps=substeps,
+ window_duration=window_duration,
+ )
+ title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
+ improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
+ title += f" ({improvement:+.1f}%)"
+
+ fig.suptitle(title, fontsize=12)
+ plt.tight_layout()
+
+ if save_path:
+ save_path = Path(save_path)
+ fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
+ log.info("figure_saved", path=str(save_path))
+
+ if show:
+ plt.show()
+ else:
+ plt.close(fig)
+
+
+# ── CLI entry point ──────────────────────────────────────────────────
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Visualise system identification results."
+ )
+ parser.add_argument(
+ "--robot-path",
+ type=str,
+ default="assets/rotary_cartpole",
+ )
+ parser.add_argument(
+ "--recording",
+ type=str,
+ required=True,
+ help="Path to .npz recording file",
+ )
+ parser.add_argument(
+ "--result",
+ type=str,
+ default=None,
+ help="Path to sysid_result.json (auto-detected if omitted)",
+ )
+ parser.add_argument("--sim-dt", type=float, default=0.002)
+ parser.add_argument("--substeps", type=int, default=10)
+ parser.add_argument(
+ "--window-duration",
+ type=float,
+ default=0.5,
+ help="Shooting window length in seconds (0 = open-loop)",
+ )
+ parser.add_argument(
+ "--save",
+ type=str,
+ default=None,
+ help="Save figure to this path (PNG, PDF, …)",
+ )
+ parser.add_argument(
+ "--no-show",
+ action="store_true",
+ help="Don't show interactive window (useful for CI)",
+ )
+ args = parser.parse_args()
+
+ visualize(
+ robot_path=args.robot_path,
+ recording_path=args.recording,
+ result_path=args.result,
+ sim_dt=args.sim_dt,
+ substeps=args.substeps,
+ window_duration=args.window_duration,
+ save_path=args.save,
+ show=not args.no_show,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/training/trainer.py b/src/training/trainer.py
index ce66e5f..72d0202 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -35,6 +35,11 @@ class TrainerConfig:
hidden_sizes: tuple[int, ...] = (64, 64)
+ # Policy
+ initial_log_std: float = 0.5 # initial exploration noise
+ min_log_std: float = -2.0 # minimum exploration noise
+ max_log_std: float = 2.0 # maximum exploration noise (2.0 ≈ σ=7.4)
+
# Training
total_timesteps: int = 1_000_000
log_interval: int = 10
@@ -110,6 +115,7 @@ class VideoRecordingTrainer(SequentialTrainer):
return self._tcfg.record_video_fps
dt = getattr(self.env.config, "dt", 0.02)
substeps = getattr(self.env.config, "substeps", 1)
+ # SerialRunner has dt but no substeps — dt *is* the control period.
return max(1, int(round(1.0 / (dt * substeps))))
def _record_video(self, timestep: int) -> None:
@@ -181,8 +187,9 @@ class Trainer:
action_space=act_space,
device=device,
hidden_sizes=self.config.hidden_sizes,
- initial_log_std=0.5,
- min_log_std=-2.0,
+ initial_log_std=self.config.initial_log_std,
+ min_log_std=self.config.min_log_std,
+ max_log_std=self.config.max_log_std,
)
models = {"policy": self.model, "value": self.model}
diff --git a/train.py b/train.py
index 124c398..6547e38 100644
--- a/train.py
+++ b/train.py
@@ -26,8 +26,10 @@ logger = structlog.get_logger()
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
- "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
- "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
+ "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
+ "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
+ "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
+ "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
@@ -94,6 +96,7 @@ def main(cfg: DictConfig) -> None:
# execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
remote = training_dict.pop("remote", False)
+ training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
task = _init_clearml(choices, remote=remote)
env_name = choices.get("env", "cartpole")
diff --git a/viz.py b/viz.py
index 59d937e..948afad 100644
--- a/viz.py
+++ b/viz.py
@@ -1,9 +1,13 @@
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
-Usage:
+Usage (simulation):
mjpython viz.py env=rotary_cartpole
mjpython viz.py env=cartpole +com=true
+Usage (real hardware — digital twin):
+ mjpython viz.py env=rotary_cartpole runner=serial
+ mjpython viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
+
Controls:
Left/Right arrows — apply torque to first actuator
R — reset environment
@@ -15,6 +19,7 @@ import time
import hydra
import mujoco
import mujoco.viewer
+import numpy as np
import structlog
import torch
from hydra.core.hydra_config import HydraConfig
@@ -45,10 +50,64 @@ def _key_callback(keycode: int) -> None:
_reset_flag[0] = True
+def _add_action_arrow(viewer, model, data, action_val: float) -> None:
+ """Draw an arrow on the motor joint showing applied torque direction."""
+ if abs(action_val) < 0.01 or model.nu == 0:
+ return
+
+ # Get the body that the first actuator's joint belongs to
+ jnt_id = model.actuator_trnid[0, 0]
+ body_id = model.jnt_bodyid[jnt_id]
+
+ # Arrow origin: body position
+ pos = data.xpos[body_id].copy()
+ pos[2] += 0.02 # lift slightly above the body
+
+ # Arrow direction: along joint axis in world frame, scaled by action
+ axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
+ arrow_len = 0.08 * action_val
+ direction = axis * np.sign(arrow_len)
+
+ # Build rotation matrix: arrow rendered along local z-axis
+ z = direction / (np.linalg.norm(direction) + 1e-8)
+ up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
+ x = np.cross(up, z)
+ x /= np.linalg.norm(x) + 1e-8
+ y = np.cross(z, x)
+ mat = np.column_stack([x, y, z]).flatten()
+
+ # Color: green = positive, red = negative
+ rgba = np.array(
+ [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
+ dtype=np.float32,
+ )
+
+ geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
+ mujoco.mjv_initGeom(
+ geom,
+ type=mujoco.mjtGeom.mjGEOM_ARROW,
+ size=np.array([0.008, 0.008, abs(arrow_len)]),
+ pos=pos,
+ mat=mat,
+ rgba=rgba,
+ )
+ viewer.user_scn.ngeom += 1
+
+
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole")
+ runner_name = choices.get("runner", "mujoco")
+
+ if runner_name == "serial":
+ _main_serial(cfg, env_name)
+ else:
+ _main_sim(cfg, env_name)
+
+
+def _main_sim(cfg: DictConfig, env_name: str) -> None:
+ """Simulation visualization — step MuJoCo physics with keyboard control."""
# Build env + runner (single env for viz)
env = build_env(env_name, cfg)
@@ -94,8 +153,10 @@ def main(cfg: DictConfig) -> None:
action = torch.tensor([[action_val]])
obs, reward, terminated, truncated, info = runner.step(action)
- # Sync viewer
+ # Sync viewer with action arrow overlay
mujoco.mj_forward(model, data)
+ viewer.user_scn.ngeom = 0 # clear previous frame's overlays
+ _add_action_arrow(viewer, model, data, action_val)
viewer.sync()
# Print state
@@ -112,5 +173,75 @@ def main(cfg: DictConfig) -> None:
runner.close()
+def _main_serial(cfg: DictConfig, env_name: str) -> None:
+ """Digital-twin visualization — mirror real hardware in MuJoCo viewer.
+
+ The MuJoCo model is loaded for rendering only. Joint positions are
+ read from the ESP32 over serial and applied to the model each frame.
+ Keyboard arrows send motor commands to the real robot.
+ """
+ from src.runners.serial import SerialRunner, SerialRunnerConfig
+
+ env = build_env(env_name, cfg)
+ runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
+ serial_runner = SerialRunner(
+ env=env, config=SerialRunnerConfig(**runner_dict)
+ )
+
+ # Load MuJoCo model for visualisation (same URDF the sim uses).
+ serial_runner._ensure_viz_model()
+ model = serial_runner._viz_model
+ data = serial_runner._viz_data
+
+ with mujoco.viewer.launch_passive(
+ model, data, key_callback=_key_callback
+ ) as viewer:
+ # Show CoM / inertia if requested.
+ show_com = cfg.get("com", False)
+ if show_com:
+ viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
+ viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
+
+ logger.info(
+ "viewer_started",
+ env=env_name,
+ mode="serial (digital twin)",
+ port=serial_runner.config.port,
+ controls="Left/Right arrows = motor command, R = reset",
+ )
+
+ while viewer.is_running():
+ # Read action from keyboard callback.
+ if time.time() - _action_time[0] < _ACTION_HOLD_S:
+ action_val = _action_val[0]
+ else:
+ action_val = 0.0
+
+ # Reset on R press.
+ if _reset_flag[0]:
+ _reset_flag[0] = False
+ serial_runner._send("M0")
+ serial_runner._drive_to_center()
+ serial_runner._wait_for_pendulum_still()
+ logger.info("reset (drive-to-center + settle)")
+
+ # Send motor command to real hardware.
+ motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
+ serial_runner._send(f"M{motor_speed}")
+
+ # Sync MuJoCo model with real sensor data.
+ serial_runner._sync_viz()
+
+ # Render overlays and sync viewer.
+ viewer.user_scn.ngeom = 0
+ _add_action_arrow(viewer, model, data, action_val)
+ viewer.sync()
+
+ # Real-time pacing (~50 Hz, matches serial dt).
+ time.sleep(serial_runner.config.dt)
+
+ serial_runner.close()
+
+
if __name__ == "__main__":
main()