24 Commits

Author SHA1 Message Date
1e0836e1bc ♻️ full agent refactor 2026-06-10 21:15:34 +02:00
a98e86ef66 disable JAX GPU preallocation so MJX shares VRAM with torch 2026-06-10 19:48:48 +02:00
4210b6cb53 jax[cuda12] on linux for GPU; EGL headless render; non-fatal video 2026-06-10 19:25:33 +02:00
a6fbde798a pin skrl/jax/mujoco/gymnasium versions; custom CUDA base image 2026-06-10 09:05:48 +02:00
56499ebe97 feat: full DR (friction/damping/torque) in MJX JIT step 2026-06-09 21:25:05 +02:00
b37cd26690 feat: sim2real domain randomization + reward fixes for rotary cartpole
Close the sim2real gap for the Furuta pendulum (swings up but can't
balance on hardware). Root causes were (a) no domain randomization, so
the policy overfit one deterministic sim instance, and (b) reward design
flaws that produced degenerate policies.

Domain randomization (runner-level, backend-agnostic):
- BaseRunner: domain_rand config; per-env action-delay buffer (latency),
  Gaussian qpos/qvel sensor noise, per-env dynamics-scale sampling
  (friction/damping/torque), resampled per episode. Sensor noise per step.
- privileged_obs/privileged_dim expose normalized DR factors (mu) for RMA.
- step() now uses clean state for reward/termination, noisy state for the
  observation the policy sees.
- MuJoCoRunner: applies per-env friction/damping/torque scales.
- robot.py: compute_motor_force gains friction/damping scale args.
- Configs: DR blocks for mujoco (full) and mjx (delay+noise); clean
  defaults for mujoco_single/serial; noise/delay anchored to recordings.

Reward fixes (rotary_cartpole):
- Shift upright reward to [0,1] (was [-1,1]) + alive_bonus, so surviving
  always beats ending early (kills the "suicide into the limit" policy).
- Add balance_bonus * upright * stillness so reward requires upright AND
  near-zero pendulum velocity (kills the "spin in full loops" policy).

Deploy:
- eval.py load_policy reconstructs the history/adaptation encoder
  (auto-detects its dim from the checkpoint) so DR+embedding policies load.

Fixes:
- MuJoCoRunner._sim_reset referenced self._env (typo) -> self.env, which
  was breaking every rotary-cartpole reset.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 20:48:25 +02:00
8cc84d6a21 feat: RMA-style history-conditioned policy for sim2real adaptation
Added a temporal observation history buffer and 1D-CNN encoder so the
policy can implicitly infer environment parameters (mass, friction,
gear ratios, etc.) from recent (obs, action) dynamics.

Architecture:
  history window [(obs₀,a₀), ..., (obs_{H-1},a_{H-1})]
      → 1D-CNN HistoryEncoder → embedding (32-dim)
      → concat [current_obs, embedding] → MLP → action

Components:
- BaseRunner: history ring buffer, _push_history/_reset_history,
  augmented obs space (6 + H×7 = 76 with H=10)
- HistoryEncoder (src/models/mlp.py): 2-layer temporal Conv1d + GAP
- SharedMLP: optional history_length/raw_obs_dim/embedding_dim params;
  splits augmented obs, encodes history, feeds [obs, emb] to MLP
- TrainerConfig: history_length, embedding_dim fields
- All runner configs: history_length=10 by default
- Tests: encoder shape, model with/without history, config defaults
2026-03-28 18:58:24 +01:00
8ed9afe583 chore: update robot.yaml with unified sysid cost 0.925
All 28 params tuned jointly. Now includes stribeck_friction_boost,
stribeck_vel, action_bias. Points to rotary_cartpole_tuned.urdf.
2026-03-28 18:46:45 +01:00
5880997786 refactor: merge motor sysid into unified sysid module
Unified the two separate sysid codepaths (motor-only and full-system)
into a single module that optimizes all 28 parameters jointly:

- 13 motor params (asymmetric gear, damping, friction, deadzone,
  Stribeck boost, action bias, filter tau, armature, ctrl_limit)
- 15 pendulum/arm params (mass, CoM, inertia, joint dynamics)

Key changes:
- Added stribeck_friction_boost, stribeck_vel, action_bias to
  ActuatorConfig (robot.py) and MJX runner
- Created shared src/sysid/preprocess.py (SG velocity recomputation)
- Rewrote src/sysid/rollout.py with unified MOTOR_PARAMS + PENDULUM_PARAMS
  spec and PARAM_SETS dict for flexible subset optimization
- Updated optimize.py, export.py, visualize.py to use unified params
  (removed all LOCKED_MOTOR_PARAMS references)
- Removed src/sysid/motor/ module and scripts/motor_sysid.py

Net: -1383 lines, zero code duplication between motor and full-system sysid.
2026-03-28 16:48:22 +01:00
ca0e7b8b03 clean up lot of stuff 2026-03-22 15:49:13 +01:00
d3ed1c25ad ⚗️ experimenting training runs 2026-03-12 00:38:09 +01:00
3b2d6d08f9 update hpo 2026-03-11 23:28:39 +01:00
23801857f4 ♻️ cleanup 2026-03-11 23:16:42 +01:00
3db68255f0 update registry 2026-03-11 23:11:21 +01:00
1a822bd82e 🐛 bug fixes 2026-03-11 23:07:37 +01:00
4115447022 ♻️ crazy refactor 2026-03-11 22:52:01 +01:00
35223b3560 update motor friction 2026-03-09 23:37:10 +01:00
0f13086fee remove custom ema and use mujoco motor control 2026-03-09 22:47:57 +01:00
9813319275 add limit enforce to mujoco for joints 2026-03-09 22:30:48 +01:00
70cd2cdd7d better robot joint loading 2026-03-09 22:17:28 +01:00
9be07d9186 add new ppo mjx config 2026-03-09 21:33:42 +01:00
26ccb1e902 add mjx runner 2026-03-09 21:18:19 +01:00
15da0ef2fd update urdf and dependencies 2026-03-09 20:39:02 +01:00
c753c369b4 add rotary cartpole env 2026-03-08 22:58:32 +01:00
66 changed files with 7480 additions and 451 deletions

29
.gitignore vendored
View File

@@ -1,3 +1,28 @@
outputs/
# IDE / editor
.vscode/
runs/
# Training & HPO outputs
outputs/
runs/
smac3_output/
training_log.txt
.pytest_cache/
# Real-robot capture data (large .npz recordings)
assets/**/recordings/
# MuJoCo
MUJOCO_LOG.TXT
# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
.eggs/
dist/
build/
# Temp files
*.stl
*.scad

64
README.md Normal file
View File

@@ -0,0 +1,64 @@
# RL-Framework
A small, fast RL framework for training sim2real policies on a 3D-printed
rotary (Furuta) cartpole — built to scale from a laptop CPU to a GPU worker
(ClearML) without code changes, and to grow into more robots and simulators.
## Architecture
Three orthogonal pieces, composed by Hydra config groups:
| Piece | Role | Implementations |
|---|---|---|
| **Env** (`src/envs/`) | Task logic: obs / reward / termination / init distribution. Pure torch, batched, backend-agnostic. | `rotary_cartpole` |
| **Runner** (`src/runners/`) | Physics + sim2real plumbing (DR, sensor noise, action delay, history buffer). | `mujoco` (CPU), `mjx` (GPU/JAX), `serial` (real ESP32 robot) |
| **Trainer** (`src/training/`) | skrl PPO + shared MLP with optional history encoder. | `ppo`, `ppo_mjx`, `ppo_single`, `ppo_real` |
The robot itself is described once in `assets/<robot>/robot.yaml`
(URDF + identified motor model) and shared by **training, sysid and
deployment** — the motor model (bias → deadzone → gear compensation,
Coulomb + Stribeck friction, viscous damping, first-order lag) is
implemented in `src/core/robot.py` and mirrored exactly in the MJX JIT
step (`src/runners/mjx.py`).
## Train
```bash
# CPU (64 parallel MuJoCo envs)
python scripts/train.py env=rotary_cartpole runner=mujoco training=ppo
# GPU (1024 MJX envs) — local
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx
# GPU — remote on ClearML gpu-queue
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx training.remote=true
```
Videos and scalars stream to ClearML. Checkpoints land in `runs/`.
## Sim2real recipe
1. **Capture** real trajectories: `python -m src.sysid.capture` (writes `.npz` to `assets/<robot>/recordings/`).
2. **Identify** physics: `python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz`
— CMA-ES fits inertials/joint dynamics against the recording (motor model is locked from the unified sysid). Writes `sysid_result.json` + `robot_tuned.yaml` + `*_tuned.urdf`.
3. **Validate** the fit: `python -m src.sysid.visualize`, then copy `robot_tuned.yaml``robot.yaml`.
4. **Train with DR + history**: the runner randomizes friction/damping/torque scales, sensor noise and action latency per episode (`configs/runner/mjx.yaml: domain_rand`), and appends a 10-step (obs, action) history to the observation so the policy can implicitly identify the current dynamics (`history_length`).
5. **Deploy**: `mjpython scripts/eval.py env=rotary_cartpole runner=serial checkpoint=runs/<run>/checkpoints/agent_X.pt`
## Other tools
```bash
mjpython scripts/viz.py env=rotary_cartpole # keyboard-drive the sim
mjpython scripts/viz.py runner=serial # digital twin of the real robot
python scripts/hpo.py env=rotary_cartpole training=ppo_single # ClearML + SMAC3 HPO
pytest tests/ # unit tests
```
## Adding a robot / simulator
- **Robot**: drop `assets/<name>/` (URDF + `robot.yaml`), subclass `BaseEnv`
(obs/reward/termination/`initial_state_ranges`), register in `src/core/registry.py`,
add `configs/env/<name>.yaml`.
- **Simulator**: subclass `BaseRunner` and implement `_sim_initialize`,
`_sim_step`, `_sim_reset` (full-batch return) — DR, history and the
env-side logic come for free. Register in `scripts/train.py: RUNNER_REGISTRY`.

View File

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<robot name="cartpole">
<!-- World link (fixed base) -->
<link name="world"/>
<!-- Cart (slides along x-axis) -->
<link name="cart">
<inertial>
<mass value="1.0"/>
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
</inertial>
<visual>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</visual>
<collision>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</collision>
</link>
<!-- Cart slides along x-axis -->
<joint name="cart_joint" type="prismatic">
<parent link="world"/>
<child link="cart"/>
<axis xyz="1 0 0"/>
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
</joint>
<!-- Pole (rotates around y-axis, attached on top of cart) -->
<link name="pole">
<inertial>
<origin xyz="0 0 0.3"/>
<mass value="0.1"/>
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
</inertial>
<visual>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</visual>
<collision>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</collision>
</link>
<!-- Pole rotates freely (no motor) -->
<joint name="pole_joint" type="revolute">
<parent link="cart"/>
<child link="pole"/>
<origin xyz="0 0 0.05"/>
<axis xyz="0 1 0"/>
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
<dynamics damping="0.0" friction="0.0"/>
</joint>
</robot>

View File

@@ -0,0 +1,10 @@
# Motor-only hardware config
# Encoder and motor constants for the motor-only sysid capture.
encoder:
ppr: 11 # pulses per revolution (before quadrature)
gear_ratio: 30.0 # gearbox ratio
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
motor:
max_pwm: 255 # maximum PWM command accepted by firmware

40
assets/motor/motor.xml Normal file
View File

@@ -0,0 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco model="motor_sysid">
<compiler angle="radian" autolimits="true"/>
<option timestep="0.002" integrator="Euler"/>
<worldbody>
<light pos="0 0 1" dir="0 0 -1"/>
<!-- Fixed base (gearbox housing) -->
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921"
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
<!-- Arm: rotates around motor_joint (z-axis) -->
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1"
range="-1.5708 1.5708"
damping="0.001" armature="0.0001" frictionloss="0.03"/>
<!-- Rotor disk: mass is tunable by the optimizer -->
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
contype="0" conaffinity="0"/>
<!-- Arm load: lightweight arm attached to motor shaft -->
<geom name="arm_load" type="capsule" size="0.004"
fromto="0 0 0 -0.014 0.002 0.016"
mass="0.021" rgba="0.8 0.3 0.1 1"
contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint"
gear="0.064"
ctrllimited="true" ctrlrange="-1 1"
dyntype="filter" dynprm="0.03"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Motor-only model WITHOUT arm/pendulum load.
Use for arm-off sysid captures. arm_load mass set to ~0. -->
<mujoco model="motor_sysid_bare">
<compiler angle="radian" autolimits="true"/>
<option timestep="0.002" integrator="Euler"/>
<worldbody>
<light pos="0 0 1" dir="0 0 -1"/>
<!-- Fixed base (gearbox housing) -->
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921"
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
<!-- Arm: rotates around motor_joint (z-axis) -->
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1"
range="-1.5708 1.5708"
damping="0.001" armature="0.0001" frictionloss="0.03"/>
<!-- Rotor disk: mass is tunable by the optimizer -->
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
contype="0" conaffinity="0"/>
<!-- Arm load: near-zero mass for bare-shaft sysid -->
<geom name="arm_load" type="capsule" size="0.004"
fromto="0 0 0 -0.014 0.002 0.016"
mass="0.001" rgba="0.8 0.3 0.1 0.3"
contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint"
gear="0.064"
ctrllimited="true" ctrlrange="-1 1"
dyntype="filter" dynprm="0.03"/>
</actuator>
</mujoco>

Binary file not shown.

After

Width:  |  Height:  |  Size: 463 KiB

View File

@@ -0,0 +1,67 @@
{
"best_params": {
"actuator_gear_pos": 0.3711939014035462,
"actuator_gear_neg": 0.42814281188601877,
"actuator_filter_tau": 0.022300731564787457,
"motor_damping_pos": 0.0013836218905629106,
"motor_damping_neg": 0.005196351489379768,
"motor_armature": 0.0027534181478216656,
"motor_frictionloss_pos": 0.03674406439012955,
"motor_frictionloss_neg": 0.06908200024786905,
"viscous_quadratic": 0.000958226218765762,
"back_emf_gain": 0.0036492272912788297,
"stribeck_friction_boost": 0.044748043677129666,
"stribeck_vel": 4.0513395945623705,
"rotor_mass": 0.03982640764507874,
"motor_deadzone_pos": 0.14181963932762467,
"motor_deadzone_neg": 0.031454276545010214,
"action_bias": -0.007362969452509152,
"gearbox_backlash": 1.4749880999407965e-09
},
"best_cost": 0.21167928018839952,
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/motor/recordings/motor_from_cartpole_162432.npz",
"param_names": [
"actuator_gear_pos",
"actuator_gear_neg",
"actuator_filter_tau",
"motor_damping_pos",
"motor_damping_neg",
"motor_armature",
"motor_frictionloss_pos",
"motor_frictionloss_neg",
"viscous_quadratic",
"back_emf_gain",
"stribeck_friction_boost",
"stribeck_vel",
"rotor_mass",
"motor_deadzone_pos",
"motor_deadzone_neg",
"action_bias",
"gearbox_backlash"
],
"defaults": {
"actuator_gear_pos": 0.064,
"actuator_gear_neg": 0.064,
"actuator_filter_tau": 0.03,
"motor_damping_pos": 0.003,
"motor_damping_neg": 0.003,
"motor_armature": 0.0001,
"motor_frictionloss_pos": 0.03,
"motor_frictionloss_neg": 0.03,
"viscous_quadratic": 0.0,
"back_emf_gain": 0.0,
"stribeck_friction_boost": 0.0,
"stribeck_vel": 2.0,
"rotor_mass": 0.012,
"motor_deadzone_pos": 0.08,
"motor_deadzone_neg": 0.08,
"action_bias": 0.0,
"gearbox_backlash": 0.0
},
"timestamp": "2026-03-23T20:50:53.648753",
"history_summary": {
"first_cost": 5.105499579419285,
"final_cost": 0.21167928018839952,
"generations": 500
}
}

View File

@@ -0,0 +1,19 @@
<?xml version='1.0' encoding='utf-8'?>
<mujoco model="motor_sysid">
<compiler angle="radian" autolimits="true" />
<option timestep="0.002" integrator="Euler" />
<worldbody>
<light pos="0 0 1" dir="0 0 -1" />
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921" rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0" />
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1" range="-1.5708 1.5708" damping="0" armature="0.0027534181478216656" frictionloss="0" />
<geom name="rotor_disk" type="cylinder" size="0.008 0.004" mass="0.03982640764507874" pos="0 0 0" rgba="0.6 0.6 0.6 1" contype="0" conaffinity="0" />
<geom name="arm_load" type="capsule" size="0.004" fromto="0 0 0 -0.014 0.002 0.016" mass="0.021" rgba="0.8 0.3 0.1 1" contype="0" conaffinity="0" />
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint" gear="0.3996683566447825" ctrllimited="true" ctrlrange="-1 1" dyntype="filter" dynprm="0.022300731564787457" />
</actuator>
</mujoco>

6
assets/motor/robot.yaml Normal file
View File

@@ -0,0 +1,6 @@
# Motor-only sysid config
# Minimal config for the motor-only identification pipeline.
# The optimizer patches motor.xml in-memory; this file tells it
# which MJCF to load and provides encoder/hardware constants.
mjcf: motor.xml

View File

@@ -0,0 +1,4 @@
# Motor-only sysid config — bare shaft (no arm/pendulum load)
# Use with arm-off captures to identify pure motor dynamics.
mjcf: motor_bare.xml

View File

@@ -0,0 +1,23 @@
# Tuned motor config — generated by src.sysid.motor.optimize
# Original: robot.yaml
mjcf: motor_tuned.xml
joints:
motor_joint:
armature: 0.002753
frictionloss: 0.052913
hardware_realism:
actuator_gear_pos: 0.371194
actuator_gear_neg: 0.428143
motor_damping_pos: 0.001384
motor_damping_neg: 0.005196
motor_frictionloss_pos: 0.036744
motor_frictionloss_neg: 0.069082
motor_deadzone_pos: 0.14182
motor_deadzone_neg: 0.031454
action_bias: -0.007363
viscous_quadratic: 0.000958
back_emf_gain: 0.003649
stribeck_friction_boost: 0.044748
stribeck_vel: 4.05134
gearbox_backlash: 0.0

View File

@@ -0,0 +1,23 @@
# Rotary cartpole (Furuta pendulum) — real hardware config.
# Describes the physical device for the SerialRunner.
# Robot-specific constants that don't belong in the runner config
# (which is machine-specific: port, baud) or the env config
# (which is task-specific: rewards, max_steps).
encoder:
ppr: 11 # pulses per revolution (before quadrature)
gear_ratio: 30.0 # gearbox ratio
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
safety:
max_motor_angle_deg: 90.0 # hard termination limit (physical endstop ~70-80°)
soft_limit_deg: 40.0 # progressive penalty ramp starts here
reset:
drive_speed: 80 # PWM magnitude for bang-bang drive-to-center
deadband: 15 # encoder count threshold to consider "centered"
drive_timeout: 3.0 # seconds before giving up on drive-to-center
settle_angle_deg: 2.0 # pendulum angle threshold for "still" (degrees)
settle_vel_dps: 5.0 # pendulum velocity threshold (deg/s)
settle_duration: 0.5 # how long pendulum must stay still (seconds)
settle_timeout: 30.0 # give up waiting after this (seconds)

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,30 @@
# Canonical training model — unified sysid (cost 0.925, 475 generations).
# Source: sysid_result.json → exported via src.sysid.export.
# Key physics: ~96 ms motor lag (filter_tau), Stribeck friction, driver bias.
# Regenerate with:
# python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz
# then copy robot_tuned.yaml over this file once validated
# (python -m src.sysid.visualize to compare real vs sim).
urdf: rotary_cartpole_tuned.urdf
actuators:
- joint: motor_joint
type: motor
gear: [0.846499, 1.183733] # torque constant [pos, neg]
ctrl_range: [-0.686251, 0.686251] # PWM saturation (MAX_MOTOR_SPEED / 255)
deadzone: [0.181097, 0.202072] # L298N min |ctrl| for torque [pos, neg]
damping: [0.013165, 0.015452] # viscous damping [pos, neg]
frictionloss: [0.014244, 0.001005] # Coulomb friction [pos, neg]
filter_tau: 0.096263 # 1st-order actuator lag (s) — dominant!
stribeck_friction_boost: 0.068594 # extra static friction near standstill
stribeck_vel: 5.279594 # Stribeck decay velocity (rad/s)
action_bias: 0.056566 # additive ctrl bias (driver asymmetry)
joints:
motor_joint:
armature: 0.001676 # reflected rotor inertia (kg·m²)
frictionloss: 0.0 # handled by motor model via qfrc_applied
pendulum_joint:
damping: 1.2e-05
frictionloss: 7.2e-05

View File

@@ -0,0 +1,34 @@
# Tuned robot config — generated by src.sysid.optimize
# Original: robot.yaml
# Run `python -m src.sysid.visualize` to compare real vs sim.
urdf: rotary_cartpole_tuned.urdf
actuators:
- joint: motor_joint
type: motor
gear:
- 0.846499
- 1.183733
ctrl_range:
- -0.686251
- 0.686251
deadzone:
- 0.181097
- 0.202072
damping:
- 0.013165
- 0.015452
frictionloss:
- 0.014244
- 0.001005
filter_tau: 0.096263
stribeck_friction_boost: 0.068594
stribeck_vel: 5.279594
action_bias: 0.056566
joints:
motor_joint:
armature: 0.001676
frictionloss: 0.0
pendulum_joint:
damping: 1.2e-05
frictionloss: 7.2e-05

View File

@@ -0,0 +1,80 @@
<?xml version='1.0' encoding='utf-8'?>
<robot name="rotary_cartpole">
<link name="world" />
<link name="base_link">
<inertial>
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
<mass value="0.921" />
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
</inertial>
<visual>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="base_joint" type="fixed">
<parent link="world" />
<child link="base_link" />
</joint>
<link name="arm">
<inertial>
<origin xyz="-0.0071030505291264975 0.0008511826488989179 0.007952020186701035" rpy="0 0 0" />
<mass value="0.02110029934220782" />
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="motor_joint" type="revolute">
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
<parent link="base_link" />
<child link="arm" />
<axis xyz="0 0 1" />
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
<dynamics damping="0.001" />
</joint>
<link name="pendulum">
<inertial>
<origin xyz="0.060245187591695615 -0.07601707109312682 -0.0034636702158137786" rpy="0 0 0" />
<mass value="0.03936742845036306" />
<inertia ixx="6.202768755990066e-05" iyy="3.70078470430685e-05" izz="7.827356811788924e-05" ixy="-6.925117819616428e-06" iyz="0.0" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="pendulum_joint" type="continuous">
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
<parent link="arm" />
<child link="pendulum" />
<axis xyz="0 -1 0" />
<dynamics damping="0.0001" />
</joint>
</robot>

View File

@@ -0,0 +1,80 @@
<?xml version='1.0' encoding='utf-8'?>
<robot name="rotary_cartpole">
<link name="world" />
<link name="base_link">
<inertial>
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
<mass value="0.921" />
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
</inertial>
<visual>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="base_joint" type="fixed">
<parent link="world" />
<child link="base_link" />
</joint>
<link name="arm">
<inertial>
<origin xyz="0.02679980831009001 -0.015110803875962989 -0.005337417994989926" rpy="0 0 0" />
<mass value="0.04052645463607292" />
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="motor_joint" type="revolute">
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
<parent link="base_link" />
<child link="arm" />
<axis xyz="0 0 1" />
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
<dynamics damping="0.001" />
</joint>
<link name="pendulum">
<inertial>
<origin xyz="0.04237886229290564 -0.05762212306183831 -0.0006039324398328591" rpy="0 0 0" />
<mass value="0.056793277126969105" />
<inertia ixx="0.0005956997356690322" iyy="8.986080748758447e-05" izz="0.0006855370063143978" ixy="-1.074983574165622e-05" iyz="0.0" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="pendulum_joint" type="continuous">
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
<parent link="arm" />
<child link="pendulum" />
<axis xyz="0 -1 0" />
<dynamics damping="0.0001" />
</joint>
</robot>

Binary file not shown.

After

Width:  |  Height:  |  Size: 845 KiB

View File

@@ -0,0 +1,101 @@
{
"best_params": {
"actuator_gear_pos": 0.8464986357922419,
"actuator_gear_neg": 1.1837332515116656,
"actuator_filter_tau": 0.09626340711982824,
"motor_damping_pos": 0.013165358841570964,
"motor_damping_neg": 0.015451769431035585,
"motor_armature": 0.0016762295220909746,
"motor_frictionloss_pos": 0.014244285616762612,
"motor_frictionloss_neg": 0.0010053574956085138,
"stribeck_friction_boost": 0.06859381652654695,
"stribeck_vel": 5.279593819777324,
"motor_deadzone_pos": 0.1810971675049997,
"motor_deadzone_neg": 0.20207167371643614,
"action_bias": 0.05656632270757292,
"arm_mass": 0.04052645463607292,
"arm_com_x": 0.02679980831009001,
"arm_com_y": -0.015110803875962989,
"arm_com_z": -0.005337417994989926,
"pendulum_mass": 0.056793277126969105,
"pendulum_com_x": 0.04237886229290564,
"pendulum_com_y": -0.05762212306183831,
"pendulum_com_z": -0.0006039324398328591,
"pendulum_ixx": 0.0005956997356690322,
"pendulum_iyy": 8.986080748758447e-05,
"pendulum_izz": 0.0006855370063143978,
"pendulum_ixy": -1.074983574165622e-05,
"pendulum_damping": 1.1718327130851149e-05,
"pendulum_frictionloss": 7.204244487092445e-05,
"ctrl_limit": 0.6862506602999546
},
"best_cost": 0.9250391788859942,
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260328_153749.npz",
"param_names": [
"actuator_gear_pos",
"actuator_gear_neg",
"actuator_filter_tau",
"motor_damping_pos",
"motor_damping_neg",
"motor_armature",
"motor_frictionloss_pos",
"motor_frictionloss_neg",
"stribeck_friction_boost",
"stribeck_vel",
"motor_deadzone_pos",
"motor_deadzone_neg",
"action_bias",
"arm_mass",
"arm_com_x",
"arm_com_y",
"arm_com_z",
"pendulum_mass",
"pendulum_com_x",
"pendulum_com_y",
"pendulum_com_z",
"pendulum_ixx",
"pendulum_iyy",
"pendulum_izz",
"pendulum_ixy",
"pendulum_damping",
"pendulum_frictionloss",
"ctrl_limit"
],
"defaults": {
"actuator_gear_pos": 0.371194,
"actuator_gear_neg": 0.428143,
"actuator_filter_tau": 0.022301,
"motor_damping_pos": 0.001384,
"motor_damping_neg": 0.005196,
"motor_armature": 0.002753,
"motor_frictionloss_pos": 0.036744,
"motor_frictionloss_neg": 0.069082,
"stribeck_friction_boost": 0.0,
"stribeck_vel": 2.0,
"motor_deadzone_pos": 0.14182,
"motor_deadzone_neg": 0.031454,
"action_bias": 0.0,
"arm_mass": 0.0211,
"arm_com_x": -0.0071,
"arm_com_y": 0.00085,
"arm_com_z": 0.00795,
"pendulum_mass": 0.03937,
"pendulum_com_x": 0.06025,
"pendulum_com_y": -0.07602,
"pendulum_com_z": -0.00346,
"pendulum_ixx": 6.2e-05,
"pendulum_iyy": 3.7e-05,
"pendulum_izz": 7.83e-05,
"pendulum_ixy": -6.93e-06,
"pendulum_damping": 0.0001,
"pendulum_frictionloss": 0.0001,
"ctrl_limit": 0.588
},
"preprocess_vel": true,
"timestamp": "2026-03-28T17:09:39.241413",
"history_summary": {
"first_cost": 13.490216059926947,
"final_cost": 0.9250391788859942,
"generations": 475
}
}

View File

@@ -1,5 +1,5 @@
defaults:
- env: cartpole
- env: rotary_cartpole
- runner: mujoco
- training: ppo
- _self_
- _self_

View File

@@ -1,11 +0,0 @@
max_steps: 500
angle_threshold: 0.418
cart_limit: 2.4
reward_alive: 1.0
reward_pole_upright_scale: 1.0
reward_action_penalty_scale: 0.01
model_path: assets/cartpole/cartpole.urdf
actuators:
- joint: cart_joint
gear: 10.0
ctrl_range: [-1.0, 1.0]

28
configs/env/rotary_cartpole.yaml vendored Normal file
View File

@@ -0,0 +1,28 @@
max_steps: 1000
robot_path: assets/rotary_cartpole
reward_upright_scale: 1.0
alive_bonus: 0.25 # per-step survival bonus (living must beat dying)
balance_bonus: 2.0 # extra reward for upright AND still (beats spinning)
balance_vel_scale: 0.5 # how fast the balance bonus decays with pendulum speed
# ── Regularisation penalties (prevent fast spinning) ─────────────────
motor_vel_penalty: 0.01 # penalise high motor angular velocity
motor_angle_penalty: 0.05 # penalise deviation from centre
action_penalty: 0.05 # penalise large actions (energy cost)
action_rate_penalty: 0.01 # penalise action changes (real-motor smoothness)
# ── Initial state randomisation ──────────────────────────────────────
pendulum_init_range_deg: 180.0 # pendulum starts in [-180°, +180°]
# ── Software safety limit (env-level, always applied) ────────────────
motor_angle_limit_deg: 90.0 # terminate episode if motor exceeds ±90°
# ── HPO search ranges ────────────────────────────────────────────────
hpo:
reward_upright_scale: {min: 0.5, max: 5.0}
motor_vel_penalty: {min: 0.001, max: 0.1}
motor_angle_penalty: {min: 0.01, max: 0.2}
action_penalty: {min: 0.01, max: 0.2}
action_rate_penalty: {min: 0.001, max: 0.1}
pendulum_init_range_deg: {min: 30.0, max: 180.0}
max_steps: {values: [500, 1000, 2000]}

16
configs/runner/mjx.yaml Normal file
View File

@@ -0,0 +1,16 @@
num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu
dt: 0.002
substeps: 10
history_length: 10 # (obs, action) window for implicit adaptation
# ── Domain randomization (sim-to-real) ──────────────────────────────
# Full DR on GPU: latency + sensor noise + per-env dynamics scales
# (friction/damping/torque) are all applied inside the JIT step.
domain_rand:
qpos_noise_std: 0.01 # rad — encoder angle noise
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
action_delay_steps: [0, 2] # control-step latency (040 ms)
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
damping_scale: [0.6, 1.6] # viscous-damping multiplier
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation

View File

@@ -1,4 +1,16 @@
num_envs: 16
device: cpu
dt: 0.02
substeps: 2
num_envs: 64
device: auto # auto = cuda if available, else cpu
dt: 0.002
substeps: 10
history_length: 10 # (obs, action) window for implicit adaptation
# ── Domain randomization (sim-to-real) ──────────────────────────────
# Noise/delay levels anchored to the real recordings (~50 Hz, ~0.5 rad/s
# velocity noise, ≤1-step latency). Set domain_rand: {} to disable.
domain_rand:
qpos_noise_std: 0.01 # rad — encoder angle noise
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
action_delay_steps: [0, 2] # control-step latency (040 ms)
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier
damping_scale: [0.6, 1.6] # viscous-damping multiplier
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation

View File

@@ -0,0 +1,15 @@
# Single-env MuJoCo runner — mimics real hardware timing.
# dt × substeps = 0.002 × 10 = 0.02 s → 50 Hz control, same as serial runner.
num_envs: 1
device: cpu
dt: 0.002
substeps: 10
history_length: 10
# Clean by default (deterministic eval). Confirming-experiment example —
# re-eval an existing checkpoint in sim with a fixed 1-step action delay:
# mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
# checkpoint=runs/.../agent_XXXX.pt \
# '++runner.domain_rand.action_delay_steps=[1,1]'
domain_rand: {}

View File

@@ -0,0 +1,11 @@
# Serial runner — communicates with real hardware over USB/serial.
# Always single-env, CPU-only. Override port on CLI:
# python scripts/train.py runner=serial runner.port=/dev/ttyUSB0
num_envs: 1
device: cpu
port: /dev/cu.usbserial-0001
baud: 115200
dt: 0.02 # control loop period (50 Hz, matches training)
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
history_length: 10 # must match training runner

32
configs/sysid.yaml Normal file
View File

@@ -0,0 +1,32 @@
# System identification defaults.
# Override via CLI: python scripts/sysid.py optimize --max-generations 50
#
# These are NOT Hydra config groups — the sysid scripts use argparse.
# This file serves as documentation and can be loaded by custom wrappers.
capture:
port: /dev/cu.usbserial-0001
baud: 115200
duration: 20.0 # seconds
amplitude: 150 # max PWM magnitude — must match firmware MAX_MOTOR_SPEED
hold_min_ms: 50 # PRBS min hold time
hold_max_ms: 300 # PRBS max hold time
dt: 0.02 # sample period (50 Hz)
optimize:
sigma0: 0.3 # CMA-ES initial step size (in [0,1] normalised space)
population_size: 50 # candidates per generation
max_generations: 1000 # total generations (~4000 evaluations)
sim_dt: 0.002 # MuJoCo physics timestep
substeps: 10 # physics substeps per control step (ctrl_dt = 0.02s)
pos_weight: 1.0 # MSE weight for angle errors
vel_weight: 0.1 # MSE weight for velocity errors
window_duration: 0.5 # multiple-shooting window length (s); 0 = open-loop
seed: 42
# Tunable hardware-realism params (added to ROTARY_CARTPOLE_PARAMS):
# ctrl_limit — effective motor range → exported as ctrl_range in robot.yaml
# motor_deadzone — L298N minimum |action| for torque → exported as deadzone in robot.yaml
# Firmware sends raw (unfiltered) sensor data; EMA filtering is
# handled on the Python side (env transforms) and is NOT part of
# the sysid parameter search.

View File

@@ -1,7 +1,10 @@
hidden_sizes: [128, 128]
total_timesteps: 1000000
rollout_steps: 1024
learning_epochs: 4
# PPO defaults — sized for the CPU MuJoCo runner (64 parallel envs).
# 128 rollout steps × 64 envs ≈ 8K samples per update.
hidden_sizes: [256, 256]
total_timesteps: 500000 # × 64 envs = 32M env steps
rollout_steps: 128
learning_epochs: 5
mini_batches: 4
discount_factor: 0.99
gae_lambda: 0.95
@@ -9,5 +12,31 @@ learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
entropy_loss_scale: 0.01
log_interval: 10
clearml_project: RL-Framework
kl_threshold: 0.01 # KL-adaptive LR; 0 = fixed learning rate
log_interval: 1000
checkpoint_interval: 50000
initial_log_std: -0.5
min_log_std: -4.0
max_log_std: 2.0
record_video_every: 10000
# History encoder output dim — the window size itself comes from
# runner.history_length (single source of truth).
embedding_dim: 32
# ClearML remote execution (GPU worker)
remote: false
# ── HPO search ranges ────────────────────────────────────────────────
# Read by scripts/hpo.py — ignored by TrainerConfig during training.
hpo:
learning_rate: {min: 0.00005, max: 0.001}
clip_ratio: {min: 0.1, max: 0.3}
discount_factor: {min: 0.98, max: 0.999}
gae_lambda: {min: 0.9, max: 0.99}
entropy_loss_scale: {min: 0.0001, max: 0.1}
value_loss_scale: {min: 0.1, max: 1.0}
learning_epochs: {min: 2, max: 8, type: int}
mini_batches: {values: [2, 4, 8, 16]}

View File

@@ -0,0 +1,23 @@
# PPO sized for MJX (1024+ parallel envs on GPU).
# Inherits defaults + HPO ranges from ppo.yaml.
#
# Short rollouts × many envs is the GPU-PPO sweet spot:
# 24 steps × 1024 envs ≈ 25K samples per update (~6K per mini-batch).
# (The old rollout_steps=2048 inherited from the CPU config meant a
# 2M-sample memory per update — GBs of VRAM and glacial updates.)
defaults:
- ppo
- _self_
rollout_steps: 24
mini_batches: 4
learning_epochs: 5
learning_rate: 0.0003 # KL-adaptive scheduler handles the rest
total_timesteps: 100000 # × 1024 envs ≈ 100M env steps
log_interval: 100
checkpoint_interval: 10000
record_video_every: 10000
remote: false

View File

@@ -0,0 +1,29 @@
# PPO tuned for single-env real-time training on real hardware.
# Inherits defaults + HPO ranges from ppo.yaml.
# ~50 Hz control × 1 env = ~50 timesteps/s.
# 100k timesteps ≈ 33 minutes of wall-clock training.
defaults:
- ppo
- _self_
hidden_sizes: [256, 256]
total_timesteps: 2000000
learning_epochs: 10
learning_rate: 0.0005 # conservative — can't undo real-world damage
entropy_loss_scale: 0.01
rollout_steps: 2048
mini_batches: 8
log_interval: 2048
checkpoint_interval: 5000 # frequent saves — can't rewind real hardware
initial_log_std: -0.5 # moderate initial exploration
min_log_std: -4.0
max_log_std: 2.0 # cap σ at 1.0
# Never run real-hardware training remotely
remote: false
# Tighter HPO ranges for real hardware (override base ppo.yaml ranges)
hpo:
entropy_loss_scale: {min: 0.00005, max: 0.001}
learning_rate: {min: 0.0003, max: 0.003}

View File

@@ -0,0 +1,25 @@
# PPO tuned for single-env simulation — mimics real hardware training.
# Inherits defaults + HPO ranges from ppo.yaml.
# Same 50 Hz control (runner=mujoco_single), 1 env, conservative hypers.
# Sim runs ~100× faster than real time, so we can afford more timesteps.
defaults:
- ppo
- _self_
hidden_sizes: [256, 256]
total_timesteps: 2000000
learning_epochs: 10
learning_rate: 0.0003
entropy_loss_scale: 0.01
rollout_steps: 2048
mini_batches: 8
log_interval: 2048
checkpoint_interval: 10000
initial_log_std: -0.5
min_log_std: -4.0
max_log_std: 2.0
record_video_every: 50000
remote: false

View File

@@ -1,8 +1,21 @@
torch
gymnasium
gymnasium==1.2.3
hydra-core
omegaconf
mujoco
skrl[torch]
mujoco==3.5.0
mujoco-mjx==3.5.0
jax[cuda12]==0.9.1 ; sys_platform == "linux"
jax==0.9.1 ; sys_platform != "linux"
skrl[torch]==1.4.3
clearml
imageio
imageio-ffmpeg
structlog
pyyaml
pyserial
cmaes
matplotlib
smac>=2.0.0
ConfigSpace
hpbandster
pytest

403
scripts/eval.py Normal file
View File

@@ -0,0 +1,403 @@
"""Evaluate a trained policy on real hardware (or in simulation).
Loads a checkpoint and runs the policy in a closed loop. For real
hardware the serial runner talks to the ESP32; for sim it uses the
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
in both modes.
Usage (real hardware):
mjpython scripts/eval.py env=rotary_cartpole runner=serial \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Usage (simulation):
mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
Controls:
Space — pause / resume policy (motor stops while paused)
R — reset environment
Esc — quit
"""
import math
import sys
import time
from pathlib import Path
# Ensure project root is on sys.path
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from gymnasium import spaces
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from skrl.resources.preprocessors.torch import RunningStandardScaler
from src.core.registry import build_env
from src.models.mlp import SharedMLP
logger = structlog.get_logger()
# ── keyboard state ───────────────────────────────────────────────────
_reset_flag = [False]
_paused = [False]
_quit_flag = [False]
def _key_callback(keycode: int) -> None:
"""Called by MuJoCo viewer on key press."""
if keycode == 32: # GLFW_KEY_SPACE
_paused[0] = not _paused[0]
elif keycode == 82: # GLFW_KEY_R
_reset_flag[0] = True
elif keycode == 256: # GLFW_KEY_ESCAPE
_quit_flag[0] = True
# ── checkpoint loading ───────────────────────────────────────────────
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
"""Infer hidden layer sizes from a SharedMLP state dict."""
sizes = []
i = 0
while f"net.{i}.weight" in state_dict:
sizes.append(state_dict[f"net.{i}.weight"].shape[0])
i += 2 # skip activation layers (ELU)
return tuple(sizes)
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
"""Return the history encoder output dim, if present.
Lets eval reconstruct an embedding policy without knowing the training
embedding_dim — read it straight from the saved weights.
"""
if "history_encoder.fc.weight" in state_dict:
return state_dict["history_encoder.fc.weight"].shape[0]
return None
def load_policy(
checkpoint_path: str,
observation_space: spaces.Space,
action_space: spaces.Space,
device: torch.device = torch.device("cpu"),
history_length: int = 0,
raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]:
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
For DR + history-embedding policies (history_length > 0), the history
encoder is reconstructed too — its output dim is read back from the
saved weights.
Returns:
(model, state_preprocessor) ready for inference.
"""
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
# Infer architecture from saved weights.
hidden_sizes = _infer_hidden_sizes(ckpt["policy"])
enc_out = _infer_encoder_out_dim(ckpt["policy"])
# Reconstruct model — pass through the encoder config so a DR+embedding
# checkpoint rebuilds the history encoder with matching dimensions.
model = SharedMLP(
observation_space=observation_space,
action_space=action_space,
device=device,
hidden_sizes=hidden_sizes,
history_length=history_length if enc_out else 0,
raw_obs_dim=raw_obs_dim,
embedding_dim=enc_out or 32,
)
model.load_state_dict(ckpt["policy"])
model.eval()
# Reconstruct observation normalizer.
state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
state_preprocessor.running_mean = ckpt["state_preprocessor"]["running_mean"].to(device)
state_preprocessor.running_variance = ckpt["state_preprocessor"]["running_variance"].to(device)
state_preprocessor.current_count = ckpt["state_preprocessor"]["current_count"]
# Freeze the normalizer — don't update stats during eval.
state_preprocessor.training = False
logger.info(
"checkpoint_loaded",
path=checkpoint_path,
hidden_sizes=hidden_sizes,
obs_mean=[round(x, 3) for x in state_preprocessor.running_mean.tolist()],
obs_std=[round(x, 3) for x in state_preprocessor.running_variance.sqrt().tolist()],
)
return model, state_preprocessor
# ── action arrow overlay ─────────────────────────────────────────────
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
"""Draw an arrow showing applied torque direction."""
if abs(action_val) < 0.01 or model.nu == 0:
return
jnt_id = model.actuator_trnid[0, 0]
body_id = model.jnt_bodyid[jnt_id]
pos = data.xpos[body_id].copy()
pos[2] += 0.02
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
arrow_len = 0.08 * action_val
direction = axis * np.sign(arrow_len)
z = direction / (np.linalg.norm(direction) + 1e-8)
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
x = np.cross(up, z)
x /= np.linalg.norm(x) + 1e-8
y = np.cross(z, x)
mat = np.column_stack([x, y, z]).flatten()
rgba = np.array(
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
dtype=np.float32,
)
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
mujoco.mjv_initGeom(
geom,
type=mujoco.mjtGeom.mjGEOM_ARROW,
size=np.array([0.008, 0.008, abs(arrow_len)]),
pos=pos,
mat=mat,
rgba=rgba,
)
viewer.user_scn.ngeom += 1
# ── main loops ───────────────────────────────────────────────────────
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco_single")
checkpoint_path = cfg.get("checkpoint", None)
if checkpoint_path is None:
logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
sys.exit(1)
# Resolve relative paths against original working directory.
checkpoint_path = str(Path(hydra.utils.get_original_cwd()) / checkpoint_path)
if not Path(checkpoint_path).exists():
logger.error("checkpoint_not_found", path=checkpoint_path)
sys.exit(1)
if runner_name == "serial":
_eval_serial(cfg, env_name, checkpoint_path)
else:
_eval_sim(cfg, env_name, checkpoint_path)
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
"""Evaluate policy in MuJoCo simulation with viewer."""
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
runner_dict["num_envs"] = 1
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
device = runner.device
model, preprocessor = load_policy(
checkpoint_path, runner.observation_space, runner.action_space, device,
history_length=runner.config.history_length,
raw_obs_dim=runner.env.observation_space.shape[0],
)
mj_model = runner._model
mj_data = runner._data[0]
dt_ctrl = runner.config.dt * runner.config.substeps
with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
obs, _ = runner.reset()
step = 0
episode = 0
episode_reward = 0.0
logger.info(
"eval_started",
env=env_name,
mode="simulation",
checkpoint=Path(checkpoint_path).name,
controls="Space=pause, R=reset, Esc=quit",
)
while viewer.is_running() and not _quit_flag[0]:
if _reset_flag[0]:
_reset_flag[0] = False
obs, _ = runner.reset()
step = 0
episode += 1
episode_reward = 0.0
logger.info("reset", episode=episode)
if _paused[0]:
viewer.sync()
time.sleep(0.05)
continue
# Policy inference
with torch.no_grad():
normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
action = model.act({"states": normalized_obs}, role="policy")[0]
action = action.clamp(-1.0, 1.0)
obs, reward, terminated, truncated, info = runner.step(action)
episode_reward += reward.item()
step += 1
# Sync viewer
mujoco.mj_forward(mj_model, mj_data)
viewer.user_scn.ngeom = 0
_add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
viewer.sync()
if step % 50 == 0:
joints = {mj_model.jnt(i).name: round(math.degrees(mj_data.qpos[i]), 1)
for i in range(mj_model.njnt)}
logger.debug(
"step", n=step, reward=round(reward.item(), 3),
action=round(action[0, 0].item(), 2),
ep_reward=round(episode_reward, 1), **joints,
)
if terminated.any() or truncated.any():
logger.info(
"episode_done", episode=episode, steps=step,
total_reward=round(episode_reward, 2),
reason="terminated" if terminated.any() else "truncated",
)
obs, _ = runner.reset()
step = 0
episode += 1
episode_reward = 0.0
time.sleep(dt_ctrl)
runner.close()
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
"""Evaluate policy on real hardware via serial, with digital-twin viewer."""
from src.runners.serial import SerialRunner, SerialRunnerConfig
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))
device = serial_runner.device
model, preprocessor = load_policy(
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
history_length=serial_runner.config.history_length,
raw_obs_dim=serial_runner.env.observation_space.shape[0],
)
# Set up digital-twin MuJoCo model for visualization.
serial_runner._ensure_viz_model()
mj_model = serial_runner._viz_model
mj_data = serial_runner._viz_data
with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
obs, _ = serial_runner.reset()
step = 0
episode = 0
episode_reward = 0.0
logger.info(
"eval_started",
env=env_name,
mode="real hardware (serial)",
port=serial_runner.config.port,
checkpoint=Path(checkpoint_path).name,
controls="Space=pause, R=reset, Esc=quit",
)
while viewer.is_running() and not _quit_flag[0]:
if _reset_flag[0]:
_reset_flag[0] = False
serial_runner._send("M0")
obs, _ = serial_runner.reset() # drives to center + settles
step = 0
episode += 1
episode_reward = 0.0
logger.info("reset", episode=episode)
if _paused[0]:
serial_runner._send("M0") # safety: stop motor while paused
serial_runner._sync_viz()
viewer.sync()
time.sleep(0.05)
continue
# Policy inference
with torch.no_grad():
normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
action = model.act({"states": normalized_obs}, role="policy")[0]
action = action.clamp(-1.0, 1.0)
obs, reward, terminated, truncated, info = serial_runner.step(action)
episode_reward += reward.item()
step += 1
# Sync digital twin with real sensor data.
serial_runner._sync_viz()
viewer.user_scn.ngeom = 0
_add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
viewer.sync()
if step % 25 == 0:
state = serial_runner._read_state()
logger.debug(
"step", n=step, reward=round(reward.item(), 3),
action=round(action[0, 0].item(), 2),
ep_reward=round(episode_reward, 1),
motor_deg=round(math.degrees(state["motor_rad"]), 1),
pend_deg=round(math.degrees(state["pend_rad"]), 1),
)
# Check for safety / disconnection.
if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
logger.error(
"safety_stop",
reboot=info.get("reboot_detected", False),
motor_limit=info.get("motor_limit_exceeded", False),
)
serial_runner._send("M0")
break
if terminated.any() or truncated.any():
logger.info(
"episode_done", episode=episode, steps=step,
total_reward=round(episode_reward, 2),
reason="terminated" if terminated.any() else "truncated",
)
# Auto-reset for next episode.
obs, _ = serial_runner.reset()
step = 0
episode += 1
episode_reward = 0.0
# Real-time pacing is handled by serial_runner.step() (dt sleep).
serial_runner.close()
if __name__ == "__main__":
main()

442
scripts/hpo.py Normal file
View File

@@ -0,0 +1,442 @@
"""Hyperparameter optimization for RL-Framework using ClearML + SMAC3.
Automatically creates a base training task (via Task.create), reads HPO
search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
and launches SMAC3 Successive Halving optimization.
Usage:
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single
# With HPO-specific options:
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single \\
--queue gpu-queue --total-trials 100
# Or use an existing base task:
python scripts/hpo.py --base-task-id <TASK_ID>
# Dry run (print search space only):
python scripts/hpo.py env=rotary_cartpole --dry-run
"""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
# Ensure project root is on sys.path
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import structlog
from clearml import Task
from clearml.automation import (
DiscreteParameterRange,
HyperParameterOptimizer,
UniformIntegerParameterRange,
UniformParameterRange,
)
from omegaconf import OmegaConf
logger = structlog.get_logger()
def _load_hydra_config(
env: str, runner: str, training: str
) -> dict:
"""Load and merge Hydra configs to extract HPO ranges.
We read the YAML files directly (without running Hydra) so this script
doesn't need @hydra.main — it's a ClearML optimizer, not a training job.
"""
configs_dir = Path(__file__).resolve().parent.parent / "configs"
# Load training config (handles defaults: [ppo] inheritance)
training_path = configs_dir / "training" / f"{training}.yaml"
training_cfg = OmegaConf.load(training_path)
# If the training config has defaults pointing to a base, load + merge
if "defaults" in training_cfg:
defaults = OmegaConf.to_container(training_cfg.defaults)
base_cfg = OmegaConf.create({})
for d in defaults:
if isinstance(d, str):
base_path = configs_dir / "training" / f"{d}.yaml"
if base_path.exists():
loaded = OmegaConf.load(base_path)
base_cfg = OmegaConf.merge(base_cfg, loaded)
# Remove defaults key and merge
training_no_defaults = {
k: v for k, v in OmegaConf.to_container(training_cfg).items()
if k != "defaults"
}
training_cfg = OmegaConf.merge(base_cfg, OmegaConf.create(training_no_defaults))
# Load env config
env_path = configs_dir / "env" / f"{env}.yaml"
env_cfg = OmegaConf.load(env_path) if env_path.exists() else OmegaConf.create({})
return {
"training": OmegaConf.to_container(training_cfg, resolve=True),
"env": OmegaConf.to_container(env_cfg, resolve=True),
}
def _build_hyper_parameters(config: dict) -> list:
"""Build ClearML parameter ranges from hpo: blocks in config.
Reads training.hpo and env.hpo dicts and creates appropriate
ClearML parameter range objects.
Each hpo entry can have:
{min, max} → UniformParameterRange (float)
{min, max, type: int} → UniformIntegerParameterRange
{min, max, log: true} → UniformParameterRange with log scale
{values: [...]} → DiscreteParameterRange
"""
params = []
for section in ("training", "env"):
hpo_ranges = config.get(section, {}).get("hpo", {})
if not hpo_ranges:
continue
for param_name, spec in hpo_ranges.items():
hydra_key = f"Hydra/{section}.{param_name}"
if "values" in spec:
params.append(
DiscreteParameterRange(hydra_key, values=spec["values"])
)
elif "min" in spec and "max" in spec:
if spec.get("type") == "int":
params.append(
UniformIntegerParameterRange(
hydra_key,
min_value=int(spec["min"]),
max_value=int(spec["max"]),
)
)
else:
step = spec.get("step", None)
params.append(
UniformParameterRange(
hydra_key,
min_value=float(spec["min"]),
max_value=float(spec["max"]),
step_size=step,
)
)
else:
logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)
return params
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
"""Flatten a nested dict into dot-separated keys.
Example: {"a": {"b": 1}} → {"a.b": 1}
"""
items = {}
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.update(_flatten_dict(v, new_key, sep=sep))
else:
items[new_key] = v
return items
def _create_base_task(
env: str, runner: str, training: str, queue: str
) -> str:
"""Create a base ClearML task without executing it.
Uses Task.create() to register a task pointing at scripts/train.py
with the correct Hydra overrides. The HPO optimizer will clone this.
The full resolved OmegaConf config is attached as Hydra/* parameters
so cloned trial tasks inherit the complete configuration.
"""
script_path = str(Path(__file__).resolve().parent / "train.py")
project_root = str(Path(__file__).resolve().parent.parent)
base_task = Task.create(
project_name="RL-Framework",
task_name=f"{env}-{runner}-{training} (HPO base)",
task_type=Task.TaskTypes.training,
script=script_path,
working_directory=project_root,
argparse_args=[
f"env={env}",
f"runner={runner}",
f"training={training}",
],
add_task_init_call=False,
)
# ── Attach full resolved OmegaConf config ─────────────────────
# ClearML's Hydra binding normally does this when the script runs,
# but Task.create() never executes Hydra. We replicate the binding
# manually: config group choices + all resolved values.
base_task.set_parameter("Hydra/env", env)
base_task.set_parameter("Hydra/runner", runner)
base_task.set_parameter("Hydra/training", training)
# Load and resolve the full config for each group
configs_dir = Path(__file__).resolve().parent.parent / "configs"
for section, name in [("training", training), ("env", env), ("runner", runner)]:
cfg_path = configs_dir / section / f"{name}.yaml"
if not cfg_path.exists():
continue
cfg = OmegaConf.load(cfg_path)
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
if "defaults" in cfg:
defaults = OmegaConf.to_container(cfg.defaults)
base_cfg = OmegaConf.create({})
for d in defaults:
if isinstance(d, str):
base_path = configs_dir / section / f"{d}.yaml"
if base_path.exists():
loaded = OmegaConf.load(base_path)
base_cfg = OmegaConf.merge(base_cfg, loaded)
cfg_no_defaults = {
k: v for k, v in OmegaConf.to_container(cfg).items()
if k != "defaults"
}
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
resolved = OmegaConf.to_container(cfg, resolve=True)
# Remove hpo metadata — not a real config value
resolved.pop("hpo", None)
flat = _flatten_dict(resolved)
for key, value in flat.items():
base_task.set_parameter(f"Hydra/{section}.{key}", value)
# Set docker config
base_task.set_base_docker(
"registry.kube.optimize/worker-image:latest",
docker_setup_bash_script=(
"apt-get update && apt-get install -y --no-install-recommends "
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
),
docker_arguments=[
"-e", "MUJOCO_GL=osmesa",
],
)
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
base_task.set_packages(str(req_file))
task_id = base_task.id
logger.info("base_task_created", task_id=task_id, task_name=base_task.name)
return task_id
def _parse_overrides(argv: list[str]) -> dict[str, str]:
"""Parse Hydra-style key=value overrides from argv.
Returns a dict of parsed key-value pairs. Unknown args (--flags)
are left in argv for argparse to handle.
"""
overrides = {}
remaining = []
for arg in argv:
if "=" in arg and not arg.startswith("-"):
key, value = arg.split("=", 1)
overrides[key] = value
else:
remaining.append(arg)
argv.clear()
argv.extend(remaining)
return overrides
def main() -> None:
# First pass: extract Hydra-style key=value overrides from sys.argv
raw_args = sys.argv[1:]
overrides = _parse_overrides(raw_args)
parser = argparse.ArgumentParser(
description="Hyperparameter optimization for RL-Framework",
usage="%(prog)s env=<ENV> runner=<RUNNER> training=<TRAINING> [options]",
)
parser.add_argument(
"--base-task-id",
type=str,
default=None,
help="Existing ClearML task ID to use as base (skip auto-creation)",
)
parser.add_argument("--queue", type=str, default="gpu-queue")
parser.add_argument(
"--max-concurrent", type=int, default=2,
help="Maximum concurrent trial tasks",
)
parser.add_argument(
"--total-trials", type=int, default=200,
help="Total HPO trial budget",
)
parser.add_argument(
"--min-budget", type=int, default=50_000,
help="Minimum budget (total_timesteps) per trial",
)
parser.add_argument(
"--max-budget", type=int, default=500_000,
help="Maximum budget (total_timesteps) for promoted trials",
)
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
parser.add_argument(
"--max-consecutive-failures", type=int, default=3,
help="Abort HPO after N consecutive trial failures (0 = never abort)",
)
parser.add_argument(
"--time-limit-hours", type=float, default=72,
help="Total wall-clock time limit in hours",
)
parser.add_argument(
"--objective-metric", type=str, default="Reward / Total reward (mean)",
help="ClearML scalar metric title to optimize",
)
parser.add_argument(
"--objective-series", type=str, default=None,
help="ClearML scalar metric series (default: same as title)",
)
parser.add_argument(
"--maximize", action="store_true", default=True,
help="Maximize the objective (default)",
)
parser.add_argument(
"--minimize", action="store_true", default=False,
help="Minimize the objective",
)
parser.add_argument(
"--dry-run", action="store_true",
help="Print search space and exit without running",
)
args = parser.parse_args(raw_args)
# Resolve env/runner/training from Hydra-style overrides (same as train.py)
env = overrides.get("env", "rotary_cartpole")
runner = overrides.get("runner", "mujoco_single")
training = overrides.get("training", "ppo_single")
objective_sign = "min" if args.minimize else "max"
# ── Load config and build search space ────────────────────────
config = _load_hydra_config(env, runner, training)
hyper_parameters = _build_hyper_parameters(config)
if not hyper_parameters:
logger.error(
"no_hpo_ranges_found",
hint="Add 'hpo:' blocks to your training and/or env YAML configs",
)
return
if args.dry_run:
print(f"\nSearch space ({len(hyper_parameters)} parameters):")
for p in hyper_parameters:
print(f" {p.name}: {p}")
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
return
# ── Initialize ClearML HPO task ───────────────────────────────
Task.ignore_requirements("torch")
task = Task.init(
project_name="RL-Framework",
task_name=f"HPO {env}-{runner}-{training}",
task_type=Task.TaskTypes.optimizer,
reuse_last_task_id=False,
)
task.set_base_docker(
docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
docker_arguments=[
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
"-e", "CLEARML_AGENT_FORCE_SYSTEM_SITE_PACKAGES=1",
],
)
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
task.set_packages(str(req_file))
# ── Create or reuse base task ─────────────────────────────────
# Store the base_task_id on the HPO task so that when the services
# worker re-runs this script it reuses the same base task instead
# of creating a duplicate.
if args.base_task_id:
base_task_id = args.base_task_id
logger.info("using_existing_base_task", task_id=base_task_id)
else:
existing = task.get_parameter("General/base_task_id")
if existing:
base_task_id = existing
logger.info("reusing_base_task_from_param", task_id=base_task_id)
else:
base_task_id = _create_base_task(
env, runner, training, args.queue
)
task.set_parameter("General/base_task_id", base_task_id)
# ── Build objective metric ────────────────────────────────────
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
objective_title = args.objective_metric
objective_series = args.objective_series or objective_title
# ── Launch optimizer ──────────────────────────────────────────
from src.hpo.smac3 import OptimizerSMAC
optimizer = HyperParameterOptimizer(
base_task_id=base_task_id,
hyper_parameters=hyper_parameters,
objective_metric_title=objective_title,
objective_metric_series=objective_series,
objective_metric_sign=objective_sign,
optimizer_class=OptimizerSMAC,
execution_queue=args.queue,
max_number_of_concurrent_tasks=args.max_concurrent,
total_max_jobs=args.total_trials,
min_iteration_per_job=args.min_budget,
max_iteration_per_job=args.max_budget,
pool_period_min=1,
time_limit_per_job=240, # 4 hours per trial max
eta=args.eta,
budget_param_name="Hydra/training.total_timesteps",
max_consecutive_failures=args.max_consecutive_failures,
)
# Send this HPO controller to a remote services worker
task.execute_remotely(queue_name="services", exit_process=True)
# Reporting and time limits
optimizer.set_report_period(1)
optimizer.set_time_limit(in_minutes=int(args.time_limit_hours * 60))
# Start and wait
optimizer.start()
optimizer.wait()
# Get top experiments
max_retries = 5
for attempt in range(max_retries):
try:
top_exp = optimizer.get_top_experiments(top_k=10)
logger.info("top_experiments_retrieved", count=len(top_exp))
for i, t in enumerate(top_exp):
logger.info("top_experiment", rank=i + 1, task_id=t.id, name=t.name)
break
except Exception as e:
logger.warning("retry_get_top_experiments", attempt=attempt + 1, error=str(e))
if attempt < max_retries - 1:
time.sleep(5.0 * (2 ** attempt))
else:
logger.error("could_not_retrieve_top_experiments")
optimizer.stop()
logger.info("hpo_complete")
if __name__ == "__main__":
main()

57
scripts/sysid.py Normal file
View File

@@ -0,0 +1,57 @@
"""Unified CLI for system identification tools.
Usage:
python scripts/sysid.py capture --robot-path assets/rotary_cartpole --duration 20
python scripts/sysid.py optimize --robot-path assets/rotary_cartpole --recording <file>.npz
python scripts/sysid.py visualize --recording <file>.npz
python scripts/sysid.py export --robot-path assets/rotary_cartpole --result <result>.json
"""
from __future__ import annotations
import sys
from pathlib import Path
# Ensure project root is on sys.path
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
def main() -> None:
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
print(
"Usage: python scripts/sysid.py <command> [options]\n"
"\n"
"Commands:\n"
" capture Record real robot trajectory under PRBS excitation\n"
" optimize Run CMA-ES parameter optimization\n"
" visualize Plot real vs simulated trajectories\n"
" export Write tuned URDF + robot.yaml files\n"
"\n"
"Run 'python scripts/sysid.py <command> --help' for command-specific options."
)
sys.exit(0)
command = sys.argv[1]
# Remove the subcommand from argv so the module's argparse works normally
sys.argv = [f"sysid {command}"] + sys.argv[2:]
if command == "capture":
from src.sysid.capture import main as cmd_main
elif command == "optimize":
from src.sysid.optimize import main as cmd_main
elif command == "visualize":
from src.sysid.visualize import main as cmd_main
elif command == "export":
from src.sysid.export import main as cmd_main
else:
print(f"Unknown command: {command}")
print("Available commands: capture, optimize, visualize, export")
sys.exit(1)
cmd_main()
if __name__ == "__main__":
main()

130
scripts/train.py Normal file
View File

@@ -0,0 +1,130 @@
import os
import pathlib
import sys
# Ensure project root is on sys.path so `src.*` imports work
# regardless of which directory the script is invoked from.
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
# Headless rendering on Linux servers (must be set before mujoco import).
# EGL renders on the GPU directly (right for NVIDIA nodes) and avoids the
# brittle OSMesa/PyOpenGL stack. Forced (not setdefault) so a stale
# `-e MUJOCO_GL=osmesa` baked into a remote task can't override it.
if sys.platform == "linux":
os.environ["MUJOCO_GL"] = "egl"
import hydra
import hydra.utils as hydra_utils
import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv
from src.core.registry import build_env
from src.core.runner import BaseRunner
from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger()
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
"""Instantiate the right runner from the Hydra config-group name."""
if runner_name not in RUNNER_REGISTRY:
raise ValueError(
f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
)
module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
import importlib
mod = importlib.import_module(module_path)
runner_cls = getattr(mod, cls_name)
config_cls = getattr(mod, cfg_cls_name)
runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
return runner_cls(env=env, config=runner_config)
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"""Initialize ClearML task with project structure and tags."""
Task.ignore_requirements("torch")
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco")
training_name = choices.get("training", "ppo")
project = "RL-Framework"
task_name = f"{env_name}-{runner_name}-{training_name}"
tags = [env_name, runner_name, training_name]
task = Task.init(project_name=project, task_name=task_name, tags=tags)
task.set_base_docker(
"git.victormylle.be/victormylle/simple-rl-framework:latest",
docker_setup_bash_script=(
"apt-get update && apt-get install -y --no-install-recommends "
"libegl1 libgl1 libglfw3 libosmesa6 && rm -rf /var/lib/apt/lists/* "
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0"
),
docker_arguments=[
"-e", "MUJOCO_GL=egl",
],
)
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
task.set_packages(str(req_file))
# Execute remotely if requested and running locally
if remote and task.running_locally():
logger.info("executing_task_remotely", queue="gpu-queue")
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
return task
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
# ClearML init — must happen before heavy work so remote execution
# can take over early. The remote worker re-runs the full script;
# execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
remote = training_dict.pop("remote", False)
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
task = _init_clearml(choices, remote=remote)
# Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
# resume_from_task_id or any future additions)
import dataclasses as _dc
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
env_name = choices.get("env", "rotary_cartpole")
env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict)
trainer = Trainer(runner=runner, config=trainer_config)
try:
trainer.train()
finally:
trainer.close()
task.close()
if __name__ == "__main__":
main()

255
scripts/viz.py Normal file
View File

@@ -0,0 +1,255 @@
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
Usage (simulation):
mjpython scripts/viz.py env=rotary_cartpole
mjpython scripts/viz.py env=rotary_cartpole +com=true
Usage (real hardware — digital twin):
mjpython scripts/viz.py env=rotary_cartpole runner=serial
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
Controls:
Left/Right arrows — apply torque to first actuator
R — reset environment
Esc / close window — quit
"""
import math
import sys
import time
from pathlib import Path
# Ensure project root is on sys.path
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import hydra
import mujoco
import mujoco.viewer
import numpy as np
import structlog
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.registry import build_env
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
logger = structlog.get_logger()
# ── keyboard state ───────────────────────────────────────────────────
_action_val = [0.0] # mutable container shared with callback
_action_time = [0.0] # timestamp of last key press
_reset_flag = [False]
_ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event
def _key_callback(keycode: int) -> None:
"""Called by MuJoCo on key press & repeat (not release)."""
if keycode == 263: # GLFW_KEY_LEFT
_action_val[0] = -1.0
_action_time[0] = time.time()
elif keycode == 262: # GLFW_KEY_RIGHT
_action_val[0] = 1.0
_action_time[0] = time.time()
elif keycode == 82: # GLFW_KEY_R
_reset_flag[0] = True
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
"""Draw an arrow on the motor joint showing applied torque direction."""
if abs(action_val) < 0.01 or model.nu == 0:
return
# Get the body that the first actuator's joint belongs to
jnt_id = model.actuator_trnid[0, 0]
body_id = model.jnt_bodyid[jnt_id]
# Arrow origin: body position
pos = data.xpos[body_id].copy()
pos[2] += 0.02 # lift slightly above the body
# Arrow direction: along joint axis in world frame, scaled by action
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
arrow_len = 0.08 * action_val
direction = axis * np.sign(arrow_len)
# Build rotation matrix: arrow rendered along local z-axis
z = direction / (np.linalg.norm(direction) + 1e-8)
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
x = np.cross(up, z)
x /= np.linalg.norm(x) + 1e-8
y = np.cross(z, x)
mat = np.column_stack([x, y, z]).flatten()
# Color: green = positive, red = negative
rgba = np.array(
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
dtype=np.float32,
)
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
mujoco.mjv_initGeom(
geom,
type=mujoco.mjtGeom.mjGEOM_ARROW,
size=np.array([0.008, 0.008, abs(arrow_len)]),
pos=pos,
mat=mat,
rgba=rgba,
)
viewer.user_scn.ngeom += 1
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco")
if runner_name == "serial":
_main_serial(cfg, env_name)
else:
_main_sim(cfg, env_name)
def _main_sim(cfg: DictConfig, env_name: str) -> None:
"""Simulation visualization — step MuJoCo physics with keyboard control."""
# Build env + runner (single env for viz)
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
runner_dict["num_envs"] = 1
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
model = runner._model
data = runner._data[0]
# Control period
dt_ctrl = runner.config.dt * runner.config.substeps
# Launch viewer
with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
# Show CoM / inertia if requested via Hydra override: viz.py +com=true
show_com = cfg.get("com", False)
if show_com:
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
obs, _ = runner.reset()
step = 0
logger.info("viewer_started", env=env_name,
controls="Left/Right arrows = torque, R = reset")
while viewer.is_running():
# Read action from callback (expires after _ACTION_HOLD_S)
if time.time() - _action_time[0] < _ACTION_HOLD_S:
action_val = _action_val[0]
else:
action_val = 0.0
# Reset on R press
if _reset_flag[0]:
_reset_flag[0] = False
obs, _ = runner.reset()
step = 0
logger.info("reset")
# Step through runner
action = torch.tensor([[action_val]])
obs, reward, terminated, truncated, info = runner.step(action)
# Sync viewer with action arrow overlay
mujoco.mj_forward(model, data)
viewer.user_scn.ngeom = 0 # clear previous frame's overlays
_add_action_arrow(viewer, model, data, action_val)
viewer.sync()
# Print state
if step % 25 == 0:
joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
for i in range(model.njnt)}
logger.debug("step", n=step, reward=round(reward.item(), 3),
action=round(action_val, 1), **joints)
# Real-time pacing
time.sleep(dt_ctrl)
step += 1
runner.close()
def _main_serial(cfg: DictConfig, env_name: str) -> None:
"""Digital-twin visualization — mirror real hardware in MuJoCo viewer.
The MuJoCo model is loaded for rendering only. Joint positions are
read from the ESP32 over serial and applied to the model each frame.
Keyboard arrows send motor commands to the real robot.
"""
from src.runners.serial import SerialRunner, SerialRunnerConfig
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
serial_runner = SerialRunner(
env=env, config=SerialRunnerConfig(**runner_dict)
)
# Load MuJoCo model for visualisation (same URDF the sim uses).
serial_runner._ensure_viz_model()
model = serial_runner._viz_model
data = serial_runner._viz_data
with mujoco.viewer.launch_passive(
model, data, key_callback=_key_callback
) as viewer:
# Show CoM / inertia if requested.
show_com = cfg.get("com", False)
if show_com:
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
logger.info(
"viewer_started",
env=env_name,
mode="serial (digital twin)",
port=serial_runner.config.port,
controls="Left/Right arrows = motor command, R = reset",
)
while viewer.is_running():
# Read action from keyboard callback.
if time.time() - _action_time[0] < _ACTION_HOLD_S:
action_val = _action_val[0]
else:
action_val = 0.0
# Reset on R press.
if _reset_flag[0]:
_reset_flag[0] = False
serial_runner._send("M0")
serial_runner._drive_to_center()
serial_runner._wait_for_settle()
logger.info("reset (drive-to-center + settle)")
# Send motor command to real hardware (same PWM scaling as
# the policy path: ctrl_range-limited).
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
serial_runner._send(f"M{motor_speed}")
# Sync MuJoCo model with real sensor data.
serial_runner._sync_viz()
# Render overlays and sync viewer.
viewer.user_scn.ngeom = 0
_add_action_arrow(viewer, model, data, action_val)
viewer.sync()
# Real-time pacing (~50 Hz, matches serial dt).
time.sleep(serial_runner.config.dt)
serial_runner.close()
if __name__ == "__main__":
main()

View File

@@ -1,33 +1,25 @@
import abc
import dataclasses
from typing import TypeVar, Generic, Any
from gymnasium import spaces
import numpy as np
import torch
import pathlib
from gymnasium import spaces
from src.core.robot import RobotConfig, load_robot_config
T = TypeVar("T")
@dataclasses.dataclass
class ActuatorConfig:
"""Actuator definition — maps a joint to a motor with gear ratio and control limits.
Kept in the env config (not runner config) because actuators define what the robot
can do, which determines action space — a task-level concept.
This mirrors Isaac Lab's pattern of separating actuator config from the robot file."""
joint: str = ""
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
@dataclasses.dataclass
class BaseEnvConfig:
max_steps: int = 1000
model_path: pathlib.Path | None = None
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
robot_path: str = "" # directory containing robot.yaml + URDF
class BaseEnv(abc.ABC, Generic[T]):
def __init__(self, config: BaseEnvConfig):
self.config = config
self.robot: RobotConfig = load_robot_config(config.robot_path)
@property
@abc.abstractmethod
@@ -48,7 +40,9 @@ class BaseEnv(abc.ABC, Generic[T]):
...
@abc.abstractmethod
def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
def compute_rewards(
self, state: Any, actions: torch.Tensor, prev_actions: torch.Tensor,
) -> torch.Tensor:
...
@abc.abstractmethod
@@ -57,3 +51,26 @@ class BaseEnv(abc.ABC, Generic[T]):
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
return step_counts >= self.config.max_steps
def initial_state_ranges(
self, nq: int, nv: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Per-DOF uniform ranges for initial-state randomization.
Returns (qpos_lo, qpos_hi, qvel_lo, qvel_hi) — offsets added to the
model's default state on every reset. All runners (CPU MuJoCo and
MJX) sample from these, so initial-state distributions stay
identical across backends. Default: small ±0.05 perturbation.
"""
return (
np.full(nq, -0.05), np.full(nq, 0.05),
np.full(nv, -0.05), np.full(nv, 0.05),
)
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
"""Check whether the physical robot has settled enough to start an episode.
Used by the SerialRunner after driving to center and waiting for the
pendulum. Default: always ready (sim doesn't need settling).
"""
return True

22
src/core/registry.py Normal file
View File

@@ -0,0 +1,22 @@
"""Shared env registry and builder used by train.py and viz.py."""
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
"""Instantiate the right env + config from the Hydra config-group name."""
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
env_dict.pop("hpo", None) # HPO range metadata — not an env config field
return env_cls(config_cls(**env_dict))

242
src/core/robot.py Normal file
View File

@@ -0,0 +1,242 @@
"""Robot hardware configuration — loaded from robot.yaml next to the URDF.
Separates robot hardware (actuators, joint tuning) from task config
(rewards, episode length) and from the URDF (clean CAD export).
Usage:
robot = load_robot_config(Path("assets/rotary_cartpole"))
# robot.urdf_path → resolved absolute path to the URDF
# robot.actuators → list of ActuatorConfig
# robot.joints → dict of per-joint overrides
"""
import dataclasses
import math
from pathlib import Path
import structlog
import torch
import yaml
log = structlog.get_logger()
def _as_pair(val) -> tuple[float, float]:
"""Convert scalar or [pos, neg] list to (pos, neg) tuple."""
if isinstance(val, (list, tuple)) and len(val) == 2:
return (float(val[0]), float(val[1]))
return (float(val), float(val))
@dataclasses.dataclass
class ActuatorConfig:
"""Motor/actuator attached to a joint.
Asymmetric fields use (positive_dir, negative_dir) tuples.
A scalar in YAML is expanded to a symmetric pair.
type:
motor — direct torque control (ctrl = normalised torque)
position — PD position servo (ctrl = target angle, needs kp)
velocity — P velocity servo (ctrl = target velocity, needs kp)
"""
joint: str = ""
type: str = "motor"
gear: tuple[float, float] = (1.0, 1.0) # torque constant (pos, neg)
ctrl_range: tuple[float, float] = (-1.0, 1.0) # (lower, upper) control bounds
deadzone: tuple[float, float] = (0.0, 0.0) # min |ctrl| for torque (pos, neg)
damping: tuple[float, float] = (0.0, 0.0) # viscous damping (pos, neg)
frictionloss: tuple[float, float] = (0.0, 0.0) # Coulomb friction (pos, neg)
kp: float = 0.0 # proportional gain (position / velocity actuators)
kv: float = 0.0 # derivative gain (position actuators)
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
viscous_quadratic: float = 0.0 # velocity² drag coefficient
back_emf_gain: float = 0.0 # back-EMF torque reduction
stribeck_friction_boost: float = 0.0 # extra static friction at low speed (N·m)
stribeck_vel: float = 2.0 # Stribeck decay velocity (rad/s)
action_bias: float = 0.0 # additive ctrl bias (driver asymmetry)
@property
def gear_avg(self) -> float:
return (self.gear[0] + self.gear[1]) / 2.0
@property
def has_motor_model(self) -> bool:
"""True if this actuator needs the runtime motor model."""
return (
self.gear[0] != self.gear[1]
or self.deadzone != (0.0, 0.0)
or self.damping != (0.0, 0.0)
or self.frictionloss != (0.0, 0.0)
or self.viscous_quadratic > 0
or self.back_emf_gain > 0
or self.stribeck_friction_boost > 0
or self.action_bias != 0.0
)
def transform_ctrl(self, ctrl: float) -> float:
"""Clip to ctrl_range, then apply bias, deadzone and gear compensation.
Must stay in lock-step with the vectorised JAX version in
``src/runners/mjx.py`` (step_fn) — sysid fits parameters against
THIS function, so any drift breaks the identified model.
"""
# Clip to ctrl_range first (mirrors firmware PWM saturation).
ctrl = max(self.ctrl_range[0], min(self.ctrl_range[1], ctrl))
# Additive driver bias (e.g. H-bridge asymmetry).
ctrl += self.action_bias
# Deadzone
dz_pos, dz_neg = self.deadzone
if ctrl >= 0 and ctrl < dz_pos:
return 0.0
if ctrl < 0 and ctrl > -dz_neg:
return 0.0
# Gear compensation: rescale so ctrl × gear_avg ≈ action × gear_dir
gear_avg = self.gear_avg
if gear_avg > 1e-8:
gear_dir = self.gear[0] if ctrl >= 0 else self.gear[1]
ctrl *= gear_dir / gear_avg
return ctrl
def compute_motor_force(self, vel: float, ctrl: float,
friction_scale: float = 1.0,
damping_scale: float = 1.0) -> float:
"""Asymmetric friction (Coulomb + Stribeck), damping, drag, back-EMF.
``friction_scale`` / ``damping_scale`` multiply the friction and
viscous-damping terms for per-env domain randomization
(1.0 = no randomization, the default used by sysid).
"""
torque = 0.0
# Coulomb + Stribeck friction (direction-dependent). The Stribeck
# boost adds extra friction at low speed that decays as exp(-(v/vs)²)
# — crucial for cheap brushed motors near standstill.
fl_pos, fl_neg = self.frictionloss
if abs(vel) > 1e-6:
fl = fl_pos if vel > 0 else fl_neg
if self.stribeck_friction_boost > 0:
fl += self.stribeck_friction_boost * math.exp(
-((abs(vel) / self.stribeck_vel) ** 2)
)
torque -= math.copysign(fl * friction_scale, vel)
# Viscous damping (direction-dependent)
damp = (self.damping[0] if vel > 0 else self.damping[1]) * damping_scale
torque -= damp * vel
# Quadratic velocity drag
if self.viscous_quadratic > 0:
torque -= self.viscous_quadratic * vel * abs(vel)
# Back-EMF torque reduction
if self.back_emf_gain > 0 and abs(ctrl) > 1e-6:
torque -= self.back_emf_gain * vel * math.copysign(1.0, ctrl)
return max(-10.0, min(10.0, torque))
def transform_action(self, action):
"""Vectorised clip + bias + deadzone + gear compensation (torch batch).
Must produce the same result as ``transform_ctrl`` element-wise.
"""
action = action.clamp(self.ctrl_range[0], self.ctrl_range[1])
action = action + self.action_bias
dz_pos, dz_neg = self.deadzone
if dz_pos > 0 or dz_neg > 0:
pos_dead = (action >= 0) & (action < dz_pos)
neg_dead = (action < 0) & (action > -dz_neg)
action = action.masked_fill(pos_dead | neg_dead, 0.0)
gear_avg = self.gear_avg
if gear_avg > 1e-8 and self.gear[0] != self.gear[1]:
pos = action >= 0
action = torch.where(
pos, action * (self.gear[0] / gear_avg),
action * (self.gear[1] / gear_avg),
)
return action
@dataclasses.dataclass
class JointConfig:
"""Per-joint overrides applied on top of the URDF values."""
damping: float | None = None
armature: float | None = None # reflected rotor inertia (kg·m²)
frictionloss: float | None = None # Coulomb/dry friction torque (N·m)
@dataclasses.dataclass
class RobotConfig:
"""Complete robot hardware description."""
urdf_path: Path = dataclasses.field(default_factory=lambda: Path())
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict)
def load_robot_config(robot_dir: str | Path) -> RobotConfig:
"""Load robot.yaml from a directory and resolve the URDF path.
Expected layout:
robot_dir/
robot.yaml ← hardware config
some_robot.urdf ← CAD export
meshes/ ← optional mesh files
"""
robot_dir = Path(robot_dir).resolve()
yaml_path = robot_dir / "robot.yaml"
if not yaml_path.exists():
raise FileNotFoundError(f"Robot config not found: {yaml_path}")
raw = yaml.safe_load(yaml_path.read_text())
# Resolve URDF path relative to robot.yaml directory
urdf_filename = raw.get("urdf", "")
if not urdf_filename:
raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}")
urdf_path = robot_dir / urdf_filename
if not urdf_path.exists():
raise FileNotFoundError(f"URDF not found: {urdf_path}")
# Parse actuators — ignore unknown keys (newer sysid exports may add
# fields before the loader learns about them) instead of crashing.
known_fields = {f.name for f in dataclasses.fields(ActuatorConfig)}
actuators = []
for a in raw.get("actuators", []):
unknown = set(a) - known_fields
if unknown:
log.warning(
"robot_yaml_unknown_actuator_keys",
keys=sorted(unknown), file=str(yaml_path),
)
a = {k: v for k, v in a.items() if k in known_fields}
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
for key in ("gear", "deadzone", "damping", "frictionloss"):
if key in a:
a[key] = _as_pair(a[key])
actuators.append(ActuatorConfig(**a))
# Parse joint overrides
joints = {}
for name, jcfg in raw.get("joints", {}).items():
joints[name] = JointConfig(**jcfg)
config = RobotConfig(
urdf_path=urdf_path,
actuators=actuators,
joints=joints,
)
log.debug("robot_config_loaded", robot_dir=str(robot_dir),
urdf=urdf_filename, num_actuators=len(actuators),
joint_overrides=list(joints.keys()))
return config

View File

@@ -1,9 +1,12 @@
import dataclasses
import abc
from typing import Any, Generic, TypeVar
from src.core.env import BaseEnv
import numpy as np
import torch
from src.core.env import BaseEnv
T = TypeVar("T")
@@ -11,12 +14,31 @@ T = TypeVar("T")
class BaseRunnerConfig:
num_envs: int = 1
device: str = "cpu"
history_length: int = 0 # 0 = plain obs, >0 = append (obs, action) history
# ── Domain randomization (sim-to-real) ─────────────────────────
# Empty dict = disabled (every field below is a no-op). Supported keys:
# qpos_noise_std: float — Gaussian sensor noise on joint angles (rad)
# qvel_noise_std: float — Gaussian sensor noise on joint velocities (rad/s)
# action_delay_steps: [lo, hi] — per-env integer control-step latency
# friction_scale: [lo, hi] — per-env multiplier on Coulomb friction
# damping_scale: [lo, hi] — per-env multiplier on viscous damping
# torque_scale: [lo, hi] — per-env multiplier on applied motor torque
# With history_length > 0 the policy can implicitly infer the sampled
# dynamics from the recent (obs, action) window — end-to-end adaptation.
domain_rand: dict = dataclasses.field(default_factory=dict)
class BaseRunner(abc.ABC, Generic[T]):
def __init__(self, env: BaseEnv, config: T) -> None:
self.env = env
self.config = config
# Resolve "auto" device before anything uses it
if getattr(self.config, "device", None) == "auto":
self.config.device = "cuda" if torch.cuda.is_available() else "cpu"
self._last_actions: torch.Tensor | None = None
self._sim_initialize(config)
self.observation_space = self.env.observation_space
@@ -27,6 +49,28 @@ class BaseRunner(abc.ABC, Generic[T]):
self.config.num_envs, dtype=torch.long, device=self.config.device
)
# ── Domain randomization (latency / sensor noise / dynamics) ─
self._setup_domain_rand()
# ── History buffer (implicit adaptation input) ────────────
self._history_len: int = getattr(self.config, "history_length", 0)
if self._history_len > 0:
obs_dim = self.observation_space.shape[0]
act_dim = self.action_space.shape[0]
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
# Ring buffer: (num_envs, history_length, obs_dim + act_dim)
self._history_buf = torch.zeros(
self.config.num_envs, self._history_len, self._history_step_dim,
device=self.config.device,
)
# Policy obs = [raw_obs, history_flat]
from gymnasium import spaces
aug_dim = obs_dim + self._history_len * self._history_step_dim
self.observation_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(aug_dim,),
)
@property
@abc.abstractmethod
def num_envs(self) -> int:
@@ -47,51 +91,243 @@ class BaseRunner(abc.ABC, Generic[T]):
@abc.abstractmethod
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Reset the given envs; return FULL-batch (num_envs, nq/nv) state.
Returning the full batch (not just the reset envs) lets GPU
backends hand back zero-copy views without host synchronisation —
the caller indexes the reset rows itself.
"""
...
@abc.abstractmethod
def _sim_close(self) -> None:
...
"""Release simulator resources. Override for extra cleanup."""
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
# ── Domain randomization ─────────────────────────────────────
_SCALE_FIELDS = ("friction_scale", "damping_scale", "torque_scale")
def _setup_domain_rand(self) -> None:
"""Parse the domain_rand config into per-env buffers.
All buffers are no-ops when ``domain_rand`` is empty: scales are 1.0,
delay is 0 and noise std is 0.
"""
dr = dict(getattr(self.config, "domain_rand", {}) or {})
n = self.config.num_envs
dev = self.config.device
# Fixed (not per-env) Gaussian sensor noise.
self._qpos_noise_std = float(dr.get("qpos_noise_std", 0.0))
self._qvel_noise_std = float(dr.get("qvel_noise_std", 0.0))
# Per-env multiplicative dynamics scales (applied by the sim runner).
self._dr_scales: dict[str, torch.Tensor] = {
f: torch.ones(n, device=dev) for f in self._SCALE_FIELDS
}
self._dr_scale_ranges: dict[str, tuple[float, float]] = {}
for f in self._SCALE_FIELDS:
rng = dr.get(f)
if rng:
self._dr_scale_ranges[f] = (float(rng[0]), float(rng[1]))
# Per-env integer action delay (in control steps).
self._dr_delay = torch.zeros(n, dtype=torch.long, device=dev)
delay_range = dr.get("action_delay_steps")
if delay_range:
self._delay_range = (int(delay_range[0]), int(delay_range[1]))
self._max_delay = int(delay_range[1])
else:
self._delay_range = (0, 0)
self._max_delay = 0
# Action-delay ring buffer: (num_envs, max_delay + 1, act_dim).
if self._max_delay > 0:
act_dim = self.env.action_space.shape[0]
self._action_buf = torch.zeros(
n, self._max_delay + 1, act_dim, device=dev,
)
def _resample_domain_rand(self, env_ids: torch.Tensor) -> None:
"""Sample fresh per-env DR factors (call on every (re)set)."""
if env_ids.numel() == 0:
return
dev = self.config.device
for name, (lo, hi) in self._dr_scale_ranges.items():
vals = torch.rand(env_ids.numel(), device=dev) * (hi - lo) + lo
self._dr_scales[name][env_ids] = vals
if self._max_delay > 0:
self._dr_delay[env_ids] = torch.randint(
self._delay_range[0], self._delay_range[1] + 1,
(env_ids.numel(),), device=dev,
)
def _reset_action_buffer(self, env_ids: torch.Tensor) -> None:
if self._max_delay > 0:
self._action_buf[env_ids] = 0.0
def _apply_action_delay(self, actions: torch.Tensor) -> torch.Tensor:
"""Return the per-env delayed action that the simulator should apply.
The policy's commanded action is what gets stored in history; only
the action handed to ``_sim_step`` is delayed.
"""
if self._max_delay <= 0:
return actions
self._action_buf = torch.roll(self._action_buf, 1, dims=1)
self._action_buf[:, 0] = actions
idx = torch.arange(self.num_envs, device=self.device)
return self._action_buf[idx, self._dr_delay]
def _add_sensor_noise(
self, qpos: torch.Tensor, qvel: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if self._qpos_noise_std > 0:
qpos = qpos + torch.randn_like(qpos) * self._qpos_noise_std
if self._qvel_noise_std > 0:
qvel = qvel + torch.randn_like(qvel) * self._qvel_noise_std
return qpos, qvel
def _compute_obs(self, qpos: torch.Tensor, qvel: torch.Tensor) -> torch.Tensor:
"""Observation the policy sees — built from noisy (sensor) state."""
nqpos, nqvel = self._add_sensor_noise(qpos, qvel)
return self.env.compute_observations(self.env.build_state(nqpos, nqvel))
# ── Observation augmentation ─────────────────────────────────
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
"""Append the flattened (obs, action) history when enabled."""
if self._history_len <= 0:
return obs
hist_flat = self._history_buf.reshape(obs.shape[0], -1)
return torch.cat([obs, hist_flat], dim=-1)
def _push_history(self, obs: torch.Tensor, actions: torch.Tensor,
env_ids: torch.Tensor | None = None) -> None:
"""Push (obs, action) into the ring buffer (shift left, append right)."""
if self._history_len <= 0:
return
step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1)
if env_ids is None:
# All envs.
self._history_buf = torch.roll(self._history_buf, -1, dims=1)
self._history_buf[:, -1] = step
else:
self._history_buf[env_ids] = torch.roll(
self._history_buf[env_ids], -1, dims=1
)
self._history_buf[env_ids, -1] = step[env_ids]
def _reset_history(self, env_ids: torch.Tensor) -> None:
"""Zero the history buffer for reset envs."""
if self._history_len > 0:
self._history_buf[env_ids] = 0.0
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
all_ids = torch.arange(self.num_envs, device=self.device)
self._resample_domain_rand(all_ids)
self._reset_action_buffer(all_ids)
qpos, qvel = self._sim_reset(all_ids)
self.step_counts.zero_()
self._reset_history(all_ids)
obs = self._compute_obs(qpos, qvel)
return self._augment_obs(obs), {}
state = self.env.build_state(qpos, qvel)
obs = self.env.compute_observations(state)
return obs, {}
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
qpos, qvel = self._sim_step(actions)
prev_actions = (
self._last_actions
if self._last_actions is not None
else torch.zeros_like(actions)
)
self._last_actions = actions
# Latency: the simulator applies a (per-env) delayed action.
sim_actions = self._apply_action_delay(actions)
qpos, qvel = self._sim_step(sim_actions)
self.step_counts += 1
state = self.env.build_state(qpos, qvel)
obs = self.env.compute_observations(state)
rewards = self.env.compute_rewards(state, actions)
terminated = self.env.compute_terminations(state)
# Reward / termination use the TRUE state (no sensor noise) so the
# learning signal and safety checks stay clean.
clean_state = self.env.build_state(qpos, qvel)
rewards = self.env.compute_rewards(clean_state, actions, prev_actions)
terminated = self.env.compute_terminations(clean_state)
truncated = self.env.compute_truncations(self.step_counts)
# The observation the policy sees is built from the NOISY sensor state.
obs = self._compute_obs(qpos, qvel)
# Push current (obs, action) into history before augmenting.
self._push_history(obs, actions)
info: dict[str, Any] = {}
done = terminated | truncated
done_ids = done.nonzero(as_tuple=False).squeeze(-1)
if done_ids.numel() > 0:
info["final_observations"] = obs[done_ids].clone()
info["final_observations"] = self._augment_obs(obs)[done_ids].clone()
info["final_env_ids"] = done_ids.clone()
reset_qpos, reset_qvel = self._sim_reset(done_ids)
# New episode → fresh dynamics + cleared latency buffer.
self._resample_domain_rand(done_ids)
self._reset_action_buffer(done_ids)
full_qpos, full_qvel = self._sim_reset(done_ids)
self.step_counts[done_ids] = 0
self._reset_history(done_ids)
reset_state = self.env.build_state(reset_qpos, reset_qvel)
obs[done_ids] = self.env.compute_observations(reset_state)
# _sim_reset returns the full batch — index the reset rows here.
obs[done_ids] = self._compute_obs(
full_qpos[done_ids], full_qvel[done_ids],
)
# skrl expects (num_envs, 1) for rewards/terminated/truncated
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
raise NotImplementedError("Render method not implemented for this runner.")
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
"""Return a raw RGB frame. Override in subclass."""
raise NotImplementedError("Render not implemented for this runner.")
def render(self, env_idx: int = 0) -> np.ndarray:
"""Render frame with action overlay."""
frame = self._render_frame(env_idx)
if self._last_actions is not None:
ctrl = float(self._last_actions[env_idx, 0].clamp(-1.0, 1.0))
_draw_action_overlay(frame, ctrl)
return frame
def close(self) -> None:
self._sim_close()
self._sim_close()
def _draw_action_overlay(frame: np.ndarray, action: float) -> None:
"""Draw an action bar on a rendered frame (no OpenCV needed).
Bar is centered horizontally: green to the right (+), red to the left (-).
"""
h, w = frame.shape[:2]
bar_y = h - 30
bar_h = 16
bar_x_center = w // 2
bar_half_w = w // 4
bar_x_left = bar_x_center - bar_half_w
bar_x_right = bar_x_center + bar_half_w
# Background (dark grey)
frame[bar_y:bar_y + bar_h, bar_x_left:bar_x_right] = [40, 40, 40]
# Filled bar
fill_len = int(abs(action) * bar_half_w)
if action > 0:
color = [60, 200, 60] # green
x0 = bar_x_center
x1 = min(bar_x_center + fill_len, bar_x_right)
else:
color = [200, 60, 60] # red
x1 = bar_x_center
x0 = max(bar_x_center - fill_len, bar_x_left)
frame[bar_y:bar_y + bar_h, x0:x1] = color
# Center tick mark (white)
frame[bar_y:bar_y + bar_h, bar_x_center - 1:bar_x_center + 1] = [255, 255, 255]

View File

@@ -1,53 +0,0 @@
import dataclasses
import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces
@dataclasses.dataclass
class CartPoleState:
cart_pos: torch.float # (num_envs,)
cart_vel: torch.float # (num_envs,)
pole_angle: torch.float # (num_envs,)
pole_vel: torch.float # (num_envs,)
@dataclasses.dataclass
class CartPoleConfig(BaseEnvConfig):
"""CartPole task config. All values come from Hydra YAML."""
angle_threshold: float = 0.418 # ~24 degrees
cart_limit: float = 2.4
reward_alive: float = 1.0
reward_pole_upright_scale: float = 1.0
reward_action_penalty_scale: float = 0.01
class CartPoleEnv(BaseEnv[CartPoleConfig]):
def __init__(self, config: CartPoleConfig):
super().__init__(config)
@property
def observation_space(self) -> torch.Tensor:
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))
@property
def action_space(self) -> torch.Tensor:
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
return CartPoleState(
cart_pos=qpos[:, 0],
cart_vel=qvel[:, 0],
pole_angle=qpos[:, 1],
pole_vel=qvel[:, 1],
)
def compute_observations(self, state: CartPoleState) -> torch.Tensor:
return torch.stack([state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1)
def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
return self.config.reward_alive + upright - action_penalty
def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
return pole_fallen | cart_out_of_bounds

181
src/envs/rotary_cartpole.py Normal file
View File

@@ -0,0 +1,181 @@
import dataclasses
import math
import numpy as np
import torch
from gymnasium import spaces
from src.core.env import BaseEnv, BaseEnvConfig
@dataclasses.dataclass
class RotaryCartPoleState:
motor_angle: torch.Tensor # (num_envs,)
motor_vel: torch.Tensor # (num_envs,)
pendulum_angle: torch.Tensor # (num_envs,)
pendulum_vel: torch.Tensor # (num_envs,)
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
"""Rotary inverted pendulum (Furuta pendulum) task config.
The motor rotates the arm horizontally; the pendulum swings freely
at the arm tip. Goal: swing the pendulum up and balance it upright.
"""
# Reward shaping
reward_upright_scale: float = 1.0 # upright reward ∈ [0, scale]
alive_bonus: float = 0.25 # per-step survival bonus (must stay alive > die)
balance_bonus: float = 2.0 # extra reward for upright AND still (beats spinning)
balance_vel_scale: float = 0.5 # decay rate of the bonus with pendulum speed
motor_vel_penalty: float = 0.01 # penalise high motor angular velocity
motor_angle_penalty: float = 0.05 # penalise deviation from centre
action_penalty: float = 0.05 # penalise large actions (energy cost)
action_rate_penalty: float = 0.01 # penalise action changes (smoothness —
# critical with ~100 ms real motor lag)
# ── Initial state randomisation ──
pendulum_init_range_deg: float = 180.0 # pendulum starts in [-range, +range]
# ── Software safety limit (env-level, on top of URDF + hardware) ──
motor_angle_limit_deg: float = 90.0 # terminate episode if exceeded
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
"""Furuta pendulum / rotary inverted pendulum environment.
Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum
Observations (6):
[sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
Using sin/cos avoids discontinuities at ±π for continuous joints.
Actions (1):
Torque applied to the motor_joint (normalised to [-1, 1]).
"""
def __init__(self, config: RotaryCartPoleConfig):
super().__init__(config)
# ── Spaces ───────────────────────────────────────────────────
@property
def observation_space(self) -> spaces.Space:
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))
@property
def action_space(self) -> spaces.Space:
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
# ── State building ───────────────────────────────────────────
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
return RotaryCartPoleState(
motor_angle=qpos[:, 0],
motor_vel=qvel[:, 0],
pendulum_angle=qpos[:, 1],
pendulum_vel=qvel[:, 1],
)
# ── Observations ─────────────────────────────────────────────
def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
obs = [
torch.sin(state.motor_angle),
torch.cos(state.motor_angle),
torch.sin(state.pendulum_angle),
torch.cos(state.pendulum_angle),
state.motor_vel,
state.pendulum_vel,
]
return torch.stack(obs, dim=-1)
# ── Rewards ──────────────────────────────────────────────────
def compute_rewards(
self,
state: RotaryCartPoleState,
actions: torch.Tensor,
prev_actions: torch.Tensor,
) -> torch.Tensor:
# Upright shaping ∈ [0, 1]: 0 hanging down (θ=0), 1 fully upright (θ=π).
# Non-negative by design so *surviving* always beats ending the episode early
# (otherwise the optimum is to slam the arm into the ±limit — "suicide policy").
upright = 0.5 * (1.0 - torch.cos(state.pendulum_angle))
# Balanced bonus: large ONLY when near the top AND nearly still. A freely
# spinning pendulum passes through the top at high speed, so stillness≈0 and
# it earns ~none of this — making true balancing strictly dominate the
# "just keep spinning in full loops" local optimum.
stillness = torch.exp(-self.config.balance_vel_scale * state.pendulum_vel.pow(2))
balance = self.config.balance_bonus * upright * stillness
# Per-step alive bonus keeps a not-yet-upright step net-positive despite
# penalties, so the 10 termination penalty is genuinely a deterrent.
reward = (upright * self.config.reward_upright_scale
+ balance
+ self.config.alive_bonus)
# Penalise fast motor spinning (discourages violent oscillation)
reward = reward - self.config.motor_vel_penalty * state.motor_vel.pow(2)
# Penalise motor deviation from centre (keep arm near zero)
reward = reward - self.config.motor_angle_penalty * state.motor_angle.pow(2)
# Penalise large actions (energy efficiency / smoother control)
reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2)
# Penalise rapid action changes — a jittery policy is unrealisable
# through the real motor's ~100 ms lag and excites unmodeled dynamics.
action_rate = (actions - prev_actions).squeeze(-1).pow(2)
reward = reward - self.config.action_rate_penalty * action_rate
# Penalty for exceeding motor angle limit (episode also terminates)
limit_rad = math.radians(self.config.motor_angle_limit_deg)
exceeded = state.motor_angle.abs() >= limit_rad
reward = torch.where(exceeded, torch.tensor(-10.0, device=reward.device), reward)
return reward
# ── Initial state randomization ──────────────────────────────
def initial_state_ranges(
self, nq: int, nv: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Small motor perturbation; wide pendulum angle (swing-up task)."""
qpos_lo = np.full(nq, -0.05)
qpos_hi = np.full(nq, 0.05)
qvel_lo = np.full(nv, -0.05)
qvel_hi = np.full(nv, 0.05)
pend_range = math.radians(self.config.pendulum_init_range_deg)
if pend_range > 0 and nq >= 2:
qpos_lo[1] = -pend_range
qpos_hi[1] = pend_range
return qpos_lo, qpos_hi, qvel_lo, qvel_hi
# ── Terminations ─────────────────────────────────────────────
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
# Software safety: terminate if motor angle exceeds limit.
limit_rad = math.radians(self.config.motor_angle_limit_deg)
exceeded = state.motor_angle.abs() >= limit_rad
return exceeded
# ── Reset readiness (for SerialRunner) ───────────────────────
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
"""Pendulum must be hanging still and motor near center."""
motor_angle = float(qpos[0, 0])
pend_angle = float(qpos[0, 1])
motor_vel = float(qvel[0, 0])
pend_vel = float(qvel[0, 1])
# Pendulum near hanging-down (angle ~0) and slow
angle_ok = abs(pend_angle) < math.radians(2.0)
vel_ok = abs(pend_vel) < math.radians(5.0)
# Motor near center and slow
motor_ok = abs(motor_angle) < math.radians(5.0)
motor_vel_ok = abs(motor_vel) < math.radians(10.0)
return angle_ok and vel_ok and motor_ok and motor_vel_ok

1
src/hpo/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Hyperparameter optimization — SMAC3 + ClearML Successive Halving."""

684
src/hpo/smac3.py Normal file
View File

@@ -0,0 +1,684 @@
# Requires: pip install smac==2.0.0 ConfigSpace==0.4.20
import contextlib
import time
from collections.abc import Sequence
from functools import wraps
from typing import Any
from clearml import Task
from clearml.automation.optimization import Objective, SearchStrategy
from clearml.automation.parameters import Parameter
from clearml.backend_interface.session import SendError
from ConfigSpace import (
CategoricalHyperparameter,
ConfigurationSpace,
UniformFloatHyperparameter,
UniformIntegerHyperparameter,
)
from smac import MultiFidelityFacade
from smac.intensifier.successive_halving import SuccessiveHalving
from smac.runhistory.dataclasses import TrialInfo, TrialValue
from smac.scenario import Scenario
def retry_on_error(max_retries=5, initial_delay=2.0, backoff=2.0, exceptions=(Exception,)):
"""Decorator to retry a function on exception with exponential backoff."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
delay = initial_delay
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except exceptions:
if attempt == max_retries - 1:
return None # Return None instead of raising
time.sleep(delay)
delay *= backoff
return None
return wrapper
return decorator
def _encode_param_name(name: str) -> str:
"""Encode parameter name for ConfigSpace (replace / with __SLASH__)"""
return name.replace("/", "__SLASH__")
def _decode_param_name(name: str) -> str:
"""Decode parameter name back to original (replace __SLASH__ with /)"""
return name.replace("__SLASH__", "/")
def _convert_param_to_cs(param: Parameter):
"""
Convert a ClearML Parameter into a ConfigSpace hyperparameter,
adapted to ConfigSpace>=1.x (no more 'q' argument).
"""
# Encode the name to avoid ConfigSpace issues with special chars like '/'
name = _encode_param_name(param.name)
# Categorical / discrete list
if hasattr(param, "values"):
return CategoricalHyperparameter(name=name, choices=list(param.values))
# Numeric range (float or int)
if hasattr(param, "min_value") and hasattr(param, "max_value"):
min_val = param.min_value
max_val = param.max_value
# Check if this should be treated as integer
if isinstance(min_val, int) and isinstance(max_val, int):
log = getattr(param, "log_scale", False)
# Check for step_size for quantization
if hasattr(param, "step_size"):
sv = int(param.step_size)
if sv != 1:
# emulate quantization by explicit list of values
choices = list(range(min_val, max_val + 1, sv))
return CategoricalHyperparameter(name=name, choices=choices)
# Simple uniform integer range
return UniformIntegerHyperparameter(name=name, lower=min_val, upper=max_val, log=log)
else:
# Treat as float
lower, upper = float(min_val), float(max_val)
log = getattr(param, "log_scale", False)
return UniformFloatHyperparameter(name=name, lower=lower, upper=upper, log=log)
raise ValueError(f"Unsupported Parameter type: {type(param)}")
class OptimizerSMAC(SearchStrategy):
"""
SMAC3-based hyperparameter optimizer, matching OptimizerBOHB interface.
"""
def __init__(
self,
base_task_id: str,
hyper_parameters: Sequence[Parameter],
objective_metric: Objective,
execution_queue: str,
num_concurrent_workers: int,
min_iteration_per_job: int,
max_iteration_per_job: int,
total_max_jobs: int,
pool_period_min: float = 2.0,
time_limit_per_job: float | None = None,
compute_time_limit: float | None = None,
**smac_kwargs: Any,
):
# Initialize base SearchStrategy
super().__init__(
base_task_id=base_task_id,
hyper_parameters=hyper_parameters,
objective_metric=objective_metric,
execution_queue=execution_queue,
num_concurrent_workers=num_concurrent_workers,
pool_period_min=pool_period_min,
time_limit_per_job=time_limit_per_job,
compute_time_limit=compute_time_limit,
min_iteration_per_job=min_iteration_per_job,
max_iteration_per_job=max_iteration_per_job,
total_max_jobs=total_max_jobs,
)
# Expose for internal use (access private attributes from base class)
self.execution_queue = self._execution_queue
self.min_iterations = min_iteration_per_job
self.max_iterations = max_iteration_per_job
self.num_concurrent_workers = self._num_concurrent_workers # Fix: access private attribute
# Objective details
# Handle both single objective (string) and multi-objective (list) cases
if isinstance(self._objective_metric.title, list):
self.metric_title = self._objective_metric.title[0] # Use first objective
else:
self.metric_title = self._objective_metric.title
if isinstance(self._objective_metric.series, list):
self.metric_series = self._objective_metric.series[0] # Use first series
else:
self.metric_series = self._objective_metric.series
# ClearML Objective stores sign as a list, e.g., ['max'] or ['min']
objective_sign = getattr(self._objective_metric, "sign", None) or getattr(self._objective_metric, "order", None)
# Handle list case - extract first element
if isinstance(objective_sign, list):
objective_sign = objective_sign[0] if objective_sign else "max"
# Default to max if nothing found
if objective_sign is None:
objective_sign = "max"
self.maximize_metric = str(objective_sign).lower() in ("max", "max_global")
# Build ConfigSpace
self.config_space = ConfigurationSpace(seed=42)
for p in self._hyper_parameters: # Access private attribute correctly
cs_hp = _convert_param_to_cs(p)
self.config_space.add(cs_hp)
# Configure SMAC Scenario
scenario = Scenario(
configspace=self.config_space,
n_trials=self.total_max_jobs,
min_budget=float(self.min_iterations),
max_budget=float(self.max_iterations),
walltime_limit=(self.compute_time_limit * 60) if self.compute_time_limit else None,
deterministic=True,
)
# Configurable budget parameter name
# Default: Hydra/training.total_timesteps (RL-Framework convention)
self.budget_param_name = smac_kwargs.pop(
"budget_param_name", "Hydra/training.total_timesteps"
)
# Pop our custom kwargs BEFORE passing smac_kwargs to SuccessiveHalving
self.max_consecutive_failures = int(
smac_kwargs.pop("max_consecutive_failures", 3)
)
self._consecutive_failures = 0
# build the Successive Halving intensifier (NOT Hyperband!)
# Hyperband runs multiple brackets with different starting budgets - wasteful
# Successive Halving: ALL configs start at min_budget, only best get promoted
# eta controls the reduction factor (default 3 means keep top 1/3 each round)
# eta can be overridden via smac_kwargs from HyperParameterOptimizer
eta = smac_kwargs.pop("eta", 3) # Default to 3 if not specified
intensifier = SuccessiveHalving(scenario=scenario, eta=eta, **smac_kwargs)
# now pass that intensifier instance into the facade
self.smac = MultiFidelityFacade(
scenario=scenario,
target_function=lambda config, budget, seed: 0.0,
intensifier=intensifier,
overwrite=True,
)
# Bookkeeping
self.running_tasks = {} # task_id -> trial info
self.task_start_times = {} # task_id -> start time (for timeout)
self.completed_results = []
self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
self.time_limit_per_job = time_limit_per_job # Store time limit (minutes)
# Checkpoint continuation tracking: config_key -> {budget: task_id}
# Used to find the previous task's checkpoint when promoting a config
self.config_to_tasks = {} # config_key -> {budget: task_id}
# Manual Successive Halving control
self.eta = eta
self.current_budget = float(self.min_iterations)
self.configs_at_budget = {} # budget -> list of (config, score, trial)
self.pending_configs = [] # configs waiting to be evaluated at current_budget - list of (trial, prev_task_id)
self.evaluated_at_budget = [] # (config, score, trial, task_id) for current budget
self.smac_asked_configs = set() # track which configs SMAC has given us
# Calculate initial rung size for proper Successive Halving
# With eta=3: rung sizes are n, n/3, n/9, ...
# Total trials = n * (1 + 1/eta + 1/eta^2 + ...) = n * eta/(eta-1) for infinite series
# For finite rungs, calculate exactly
num_rungs = 1
b = float(self.min_iterations)
while b * eta <= self.max_iterations:
num_rungs += 1
b *= eta
# Sum of geometric series: 1 + 1/eta + 1/eta^2 + ... (num_rungs terms)
series_sum = sum(1.0 / (eta**i) for i in range(num_rungs))
self.initial_rung_size = int(self.total_max_jobs / series_sum)
self.initial_rung_size = max(self.initial_rung_size, self.num_concurrent_workers) # at least num_workers
self.configs_needed_for_rung = self.initial_rung_size # how many configs we still need for current rung
self.rung_closed = False # whether we've collected all configs for current rung
@retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
def _get_task_safe(self, task_id: str):
"""Safely get a task with retry logic."""
return Task.get_task(task_id=task_id)
@retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
def _launch_task(self, config: dict, budget: float, prev_task_id: str | None = None):
"""Launch a task with retry logic for robustness.
Args:
config: Hyperparameter configuration dict
budget: Number of epochs to train
prev_task_id: Optional task ID from previous budget to continue from (checkpoint)
"""
base = self._get_task_safe(task_id=self._base_task_id)
if base is None:
return None
clone = Task.clone(
source_task=base,
name=f"HPO Trial - {base.name}",
parent=Task.current_task().id, # Set the current HPO task as parent
)
# Override hyperparameters
for k, v in config.items():
# Decode parameter name back to original (with slashes)
original_name = _decode_param_name(k)
# Convert numpy types to Python built-in types
if hasattr(v, "item"): # numpy scalar
param_value = v.item()
elif isinstance(v, int | float | str | bool):
param_value = type(v)(v) # Ensure it's the built-in type
else:
param_value = v
clone.set_parameter(original_name, param_value)
# Override budget parameter (e.g. total_timesteps) for multi-fidelity
if self.max_iterations != self.min_iterations:
clone.set_parameter(self.budget_param_name, int(budget))
else:
clone.set_parameter(self.budget_param_name, int(self.max_iterations))
# If we have a previous task, pass its ID so the worker can download the checkpoint
if prev_task_id:
clone.set_parameter("Hydra/training.resume_from_task_id", prev_task_id)
Task.enqueue(task=clone, queue_name=self.execution_queue)
# Track start time for timeout enforcement
self.task_start_times[clone.id] = time.time()
return clone
def start(self):
controller = Task.current_task()
total_launched = 0
# Keep launching & collecting until budget exhausted
while total_launched < self.total_max_jobs:
# Check if current budget rung is complete BEFORE asking for new trials
# (no running tasks, no pending configs, and we have results for this budget)
if not self.running_tasks and not self.pending_configs and self.evaluated_at_budget:
# Rung complete! Promote top performers to next budget
# Store results for this budget
self.configs_at_budget[self.current_budget] = self.evaluated_at_budget.copy()
# Sort by score (best first)
sorted_configs = sorted(
self.evaluated_at_budget,
key=lambda x: x[1], # score
reverse=self.maximize_metric,
)
# Print rung results
for _i, (_cfg, _score, _tri, _task_id) in enumerate(sorted_configs[:5], 1):
pass
# Move to next budget?
next_budget = self.current_budget * self.eta
if next_budget <= self.max_iterations:
# How many to promote (top 1/eta)
n_promote = max(1, len(sorted_configs) // self.eta)
promoted = sorted_configs[:n_promote]
# Update budget and reset for next rung
self.current_budget = next_budget
self.evaluated_at_budget = []
self.configs_needed_for_rung = 0 # promoted configs are all we need
self.rung_closed = True # rung is pre-filled with promoted configs
# Re-queue promoted configs with new budget
# Include the previous task ID for checkpoint continuation
for _cfg, _score, old_trial, prev_task_id in promoted:
new_trial = TrialInfo(
config=old_trial.config,
instance=old_trial.instance,
seed=old_trial.seed,
budget=self.current_budget,
)
# Store as tuple: (trial, prev_task_id)
self.pending_configs.append((new_trial, prev_task_id))
else:
# All budgets complete
break
# Fill pending_configs with new trials ONLY if we haven't closed this rung yet
# For the first rung: ask SMAC for initial_rung_size configs total
# For subsequent rungs: only use promoted configs (rung is already closed)
while (
not self.rung_closed
and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
< self.initial_rung_size
and total_launched < self.total_max_jobs
):
trial = self.smac.ask()
if trial is None:
self.rung_closed = True
break
# Create new trial with forced budget (TrialInfo is frozen, can't modify)
trial_with_budget = TrialInfo(
config=trial.config,
instance=trial.instance,
seed=trial.seed,
budget=self.current_budget,
)
cfg_key = str(sorted(trial.config.items()))
if cfg_key not in self.smac_asked_configs:
self.smac_asked_configs.add(cfg_key)
# Store as tuple: (trial, None) - no previous task for new configs
self.pending_configs.append((trial_with_budget, None))
# Check if we've collected enough configs for this rung
if (
not self.rung_closed
and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
>= self.initial_rung_size
):
self.rung_closed = True
# Launch pending configs up to concurrent limit
while self.pending_configs and len(self.running_tasks) < self.num_concurrent_workers:
# Unpack tuple: (trial, prev_task_id)
trial, prev_task_id = self.pending_configs.pop(0)
t = self._launch_task(trial.config, self.current_budget, prev_task_id=prev_task_id)
if t is None:
# Launch failed, mark trial as failed and continue
# Tell SMAC this trial failed with worst possible score
cost = float("inf") if self.maximize_metric else float("-inf")
self.smac.tell(trial, TrialValue(cost=cost))
total_launched += 1
continue
self.running_tasks[t.id] = trial
# Track which task ID was used for this config at this budget
cfg_key = str(sorted(trial.config.items()))
if cfg_key not in self.config_to_tasks:
self.config_to_tasks[cfg_key] = {}
self.config_to_tasks[cfg_key][self.current_budget] = t.id
total_launched += 1
if not self.running_tasks and not self.pending_configs:
break
# Abort if too many consecutive trials failed (likely a config bug)
if (
self.max_consecutive_failures > 0
and self._consecutive_failures >= self.max_consecutive_failures
):
controller.get_logger().report_text(
f"ABORTING: {self._consecutive_failures} consecutive trial "
f"failures (limit: {self.max_consecutive_failures}). "
"Check the trial logs for errors."
)
# Stop any still-running tasks
for tid in list(self.running_tasks):
with contextlib.suppress(Exception):
t = self._get_task_safe(task_id=tid)
if t:
t.mark_stopped(force=True)
self.running_tasks.clear()
break
# Poll for finished or timed out
done = []
timed_out = []
failed_to_check = []
for tid, _tri in self.running_tasks.items():
try:
task = self._get_task_safe(task_id=tid)
if task is None:
failed_to_check.append(tid)
continue
st = task.get_status()
# Check if task completed normally
if st == Task.TaskStatusEnum.completed or st in (
Task.TaskStatusEnum.failed,
Task.TaskStatusEnum.stopped,
):
done.append(tid)
# Check for timeout (if time limit is set)
elif self.time_limit_per_job and tid in self.task_start_times:
elapsed_minutes = (time.time() - self.task_start_times[tid]) / 60.0
if elapsed_minutes > self.time_limit_per_job:
with contextlib.suppress(Exception):
task.mark_stopped(force=True)
timed_out.append(tid)
except Exception:
# Don't mark as failed immediately, might be transient
# Only mark failed after multiple consecutive failures
if not hasattr(self, "_task_check_failures"):
self._task_check_failures = {}
self._task_check_failures[tid] = self._task_check_failures.get(tid, 0) + 1
if self._task_check_failures[tid] >= 5: # 5 consecutive failures
failed_to_check.append(tid)
del self._task_check_failures[tid]
# Process tasks that failed to check
for tid in failed_to_check:
tri = self.running_tasks.pop(tid)
if tid in self.task_start_times:
del self.task_start_times[tid]
# Tell SMAC this trial failed with worst possible score
res = float("-inf") if self.maximize_metric else float("inf")
cost = -res if self.maximize_metric else res
self.smac.tell(tri, TrialValue(cost=cost))
self.completed_results.append(
{
"task_id": tid,
"config": tri.config,
"budget": tri.budget,
"value": res,
"failed": True,
}
)
# Store result with task_id for checkpoint tracking
self.evaluated_at_budget.append((tri.config, res, tri, tid))
# Process completed tasks
for tid in done:
tri = self.running_tasks.pop(tid)
if tid in self.task_start_times:
del self.task_start_times[tid]
# Clear any accumulated failures for this task
if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
del self._task_check_failures[tid]
task = self._get_task_safe(task_id=tid)
# Detect hard-failed tasks (crashed / errored) vs completed
task_failed = False
if task is not None:
st = task.get_status()
task_failed = st in (
Task.TaskStatusEnum.failed,
Task.TaskStatusEnum.stopped,
)
if task is None:
res = float("-inf") if self.maximize_metric else float("inf")
task_failed = True
else:
res = self._get_objective(task)
if res is None or res == float("-inf") or res == float("inf"):
res = float("-inf") if self.maximize_metric else float("inf")
# Track consecutive failures for abort logic
if task_failed:
self._consecutive_failures += 1
else:
self._consecutive_failures = 0 # reset on any success
cost = -res if self.maximize_metric else res
self.smac.tell(tri, TrialValue(cost=cost))
self.completed_results.append(
{
"task_id": tid,
"config": tri.config,
"budget": tri.budget,
"value": res,
}
)
# Store result for this budget rung with task_id for checkpoint tracking
self.evaluated_at_budget.append((tri.config, res, tri, tid))
iteration = len(self.completed_results)
# Always report the trial score (even if it's bad)
if res is not None and res != float("-inf") and res != float("inf"):
controller.get_logger().report_scalar(
title="Optimization", series="trial_score", value=res, iteration=iteration
)
controller.get_logger().report_scalar(
title="Optimization",
series="trial_budget",
value=tri.budget or self.max_iterations,
iteration=iteration,
)
# Update best score tracking based on actual results
if res is not None and res != float("-inf") and res != float("inf"):
if self.maximize_metric:
self.best_score_so_far = max(self.best_score_so_far, res)
elif res < self.best_score_so_far:
self.best_score_so_far = res
# Always report best score so far (shows flat line when no improvement)
if self.best_score_so_far != float("-inf") and self.best_score_so_far != float("inf"):
controller.get_logger().report_scalar(
title="Optimization", series="best_score", value=self.best_score_so_far, iteration=iteration
)
# Report running statistics
valid_scores = [
r["value"]
for r in self.completed_results
if r["value"] is not None and r["value"] != float("-inf") and r["value"] != float("inf")
]
if valid_scores:
controller.get_logger().report_scalar(
title="Optimization",
series="mean_score",
value=sum(valid_scores) / len(valid_scores),
iteration=iteration,
)
controller.get_logger().report_scalar(
title="Progress",
series="completed_trials",
value=len(self.completed_results),
iteration=iteration,
)
controller.get_logger().report_scalar(
title="Progress", series="running_tasks", value=len(self.running_tasks), iteration=iteration
)
# Process timed out tasks (treat as failed with current objective value)
for tid in timed_out:
tri = self.running_tasks.pop(tid)
if tid in self.task_start_times:
del self.task_start_times[tid]
# Clear any accumulated failures for this task
if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
del self._task_check_failures[tid]
# Try to get the last objective value before timeout
task = self._get_task_safe(task_id=tid)
if task is None:
res = float("-inf") if self.maximize_metric else float("inf")
else:
res = self._get_objective(task)
if res is None:
res = float("-inf") if self.maximize_metric else float("inf")
cost = -res if self.maximize_metric else res
self.smac.tell(tri, TrialValue(cost=cost))
self.completed_results.append(
{
"task_id": tid,
"config": tri.config,
"budget": tri.budget,
"value": res,
"timed_out": True,
}
)
# Store timed out result for this budget rung with task_id
self.evaluated_at_budget.append((tri.config, res, tri, tid))
time.sleep(self.pool_period_minutes * 60) # Fix: use correct attribute name from base class
if self.compute_time_limit and controller.get_runtime() > self.compute_time_limit * 60:
break
# Finalize
self._finalize()
return self.completed_results
@retry_on_error(max_retries=3, initial_delay=2.0, exceptions=(SendError, ConnectionError, KeyError))
def _get_objective(self, task: Task):
"""Get objective metric value with retry logic for robustness."""
if task is None:
return None
try:
m = task.get_last_scalar_metrics()
if not m:
return None
metric_data = m[self.metric_title][self.metric_series]
# ClearML returns dict with 'last', 'min', 'max' keys representing
# the last/min/max values of this series over ALL logged iterations.
# For snake_length/train_max: 'last' is the last logged train_max value,
# 'max' is the highest train_max ever logged during training.
# Use 'max' if maximizing (we want the best performance achieved),
# 'min' if minimizing, fallback to 'last'
if self.maximize_metric and "max" in metric_data:
result = metric_data["max"]
elif not self.maximize_metric and "min" in metric_data:
result = metric_data["min"]
else:
result = metric_data["last"]
return result
except (KeyError, Exception):
return None
def _finalize(self):
controller = Task.current_task()
# Report final best score
controller.get_logger().report_text(f"Final best score: {self.best_score_so_far}")
# Also try to get SMAC's incumbent for comparison
try:
incumbent = self.smac.intensifier.get_incumbent()
if incumbent is not None:
runhistory = self.smac.runhistory
# Try different ways to get the cost
incumbent_cost = None
try:
incumbent_cost = runhistory.get_cost(incumbent)
except Exception:
# Fallback: search through runhistory manually
for trial_key, trial_value in runhistory.items():
trial_config = runhistory.get_config(trial_key.config_id)
if trial_config == incumbent and (incumbent_cost is None or trial_value.cost < incumbent_cost):
incumbent_cost = trial_value.cost
if incumbent_cost is not None:
score = -incumbent_cost if self.maximize_metric else incumbent_cost
controller.get_logger().report_text(f"SMAC incumbent: {incumbent}, score: {score}")
controller.upload_artifact(
"best_config",
{"config": dict(incumbent), "score": score, "our_best_score": self.best_score_so_far},
)
else:
controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
except Exception as e:
controller.get_logger().report_text(f"Error getting SMAC incumbent: {e}")
controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})

View File

@@ -3,14 +3,95 @@ import torch.nn as nn
from gymnasium import spaces
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
class HistoryEncoder(nn.Module):
"""1D-CNN encoder over a temporal window of (obs, action) pairs.
Input: (batch, history_length, step_dim)
Output: (batch, embedding_dim)
Architecture: two temporal conv layers → global average pool → linear.
Lets the policy implicitly infer the current dynamics (friction, torque
scale, latency, …) from how the system responded to recent actions —
end-to-end adaptation when trained under domain randomization.
"""
def __init__(
self,
history_length: int,
step_dim: int,
embedding_dim: int = 32,
hidden_channels: int = 32,
) -> None:
super().__init__()
self.conv = nn.Sequential(
# (batch, step_dim, history_length) after transpose
nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1),
nn.ELU(),
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
nn.ELU(),
)
self.fc = nn.Linear(hidden_channels, embedding_dim)
def forward(self, history: torch.Tensor) -> torch.Tensor:
"""history: (batch, history_length, step_dim)."""
# Conv1d expects (batch, channels, seq_len).
x = history.transpose(1, 2)
x = self.conv(x)
# Global average pool over time.
x = x.mean(dim=-1)
return self.fc(x)
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
"""Shared policy/value network with an optional history encoder.
With ``history_length > 0`` the input states are expected to be
``[raw_obs, history_flat]`` (as produced by ``BaseRunner``); the history
window is compressed by a :class:`HistoryEncoder` and concatenated with
the raw observation before the shared backbone.
"""
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
device: torch.device,
hidden_sizes: tuple[int, ...] = (32, 32),
clip_actions: bool = False,
clip_log_std: bool = True,
min_log_std: float = -2.0,
max_log_std: float = 2.0,
initial_log_std: float = 0.0,
# ── History encoder ──────────────────────────────────────
history_length: int = 0,
raw_obs_dim: int = 0,
embedding_dim: int = 32,
):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
GaussianMixin.__init__(
self, clip_actions, clip_log_std, min_log_std, max_log_std,
)
DeterministicMixin.__init__(self, clip_actions)
layers = []
in_dim: int = self.num_observations
self._history_length = history_length
self._raw_obs_dim = raw_obs_dim
self._embedding_dim = embedding_dim
self.history_encoder: HistoryEncoder | None = None
if history_length > 0 and raw_obs_dim > 0:
step_dim = raw_obs_dim + self.num_actions
self.history_encoder = HistoryEncoder(
history_length=history_length,
step_dim=step_dim,
embedding_dim=embedding_dim,
)
in_dim = raw_obs_dim + embedding_dim
else:
in_dim = self.num_observations
# ── Shared backbone ──────────────────────────────────────
layers: list[nn.Module] = []
for hidden_size in hidden_sizes:
layers.append(nn.Linear(in_dim, hidden_size))
layers.append(nn.ELU())
@@ -19,30 +100,45 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
# Policy head
self.mean_layer = nn.Linear(in_dim, self.num_actions)
self.log_std_parameter: nn.Parameter = nn.Parameter(torch.full((self.num_actions,), initial_log_std))
self.log_std_parameter: nn.Parameter = nn.Parameter(
torch.full((self.num_actions,), initial_log_std),
)
# Value head
self.value_layer = nn.Linear(in_dim, 1)
self._shared_output: torch.Tensor | None = None
def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
def act(
self, inputs: dict[str, torch.Tensor], role: str = "",
) -> tuple[torch.Tensor, ...]:
if role == "policy":
return GaussianMixin.act(self, inputs, role)
elif role == "value":
return DeterministicMixin.act(self, inputs, role)
def _encode(self, states: torch.Tensor) -> torch.Tensor:
"""Optionally split off and encode the history window."""
if self.history_encoder is None:
return self.net(states)
obs = states[:, :self._raw_obs_dim]
hist_flat = states[:, self._raw_obs_dim:]
step_dim = self._raw_obs_dim + self.num_actions
history = hist_flat.reshape(-1, self._history_length, step_dim)
embedding = self.history_encoder(history)
return self.net(torch.cat([obs, embedding], dim=-1))
def compute(
self, inputs: dict[str, torch.Tensor], role: str = ""
) -> tuple[torch.Tensor, ...]:
self, inputs: dict[str, torch.Tensor], role: str = "",
) -> tuple[torch.Tensor, ...]:
if role == "policy":
self._shared_output = self.net(inputs["states"])
self._shared_output = self._encode(inputs["states"])
return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value":
shared_output = (
self._shared_output
if self._shared_output is not None
else self.net(inputs["states"])
else self._encode(inputs["states"])
)
self._shared_output = None
return self.value_layer(shared_output), {}
return self.value_layer(shared_output), {}

354
src/runners/mjx.py Normal file
View File

@@ -0,0 +1,354 @@
"""GPU-batched MuJoCo simulation using MJX (JAX backend).
MJX runs all environments in parallel on GPU via JAX, providing
~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).
Requirements:
pip install 'jax[cuda12]' # NVIDIA GPU
pip install jax # CPU fallback
"""
import dataclasses
import os
import structlog
import torch
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
# default JAX preallocates ~75% of GPU memory on init, starving torch and
# causing OOM at the first PPO update. Disable preallocation so both libraries
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
try:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
except ImportError as e:
raise ImportError(
"MJX runner requires JAX and MuJoCo MJX. Install with:\n"
" pip install 'jax[cuda12]' # GPU\n"
" pip install jax # CPU\n"
) from e
import numpy as np
from src.core.env import BaseEnv
from src.core.runner import BaseRunner, BaseRunnerConfig
from src.runners.mujoco import (
ActuatorLimits,
load_mujoco_model,
)
log = structlog.get_logger()
@dataclasses.dataclass
class MJXRunnerConfig(BaseRunnerConfig):
num_envs: int = 1024
device: str = "cuda"
dt: float = 0.002
substeps: int = 10
class MJXRunner(BaseRunner[MJXRunnerConfig]):
"""GPU-batched MuJoCo runner using MJX (JAX).
Physics runs entirely on GPU via JAX; observations flow to
PyTorch through zero-copy DLPack transfers.
"""
def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
super().__init__(env, config)
@property
def num_envs(self) -> int:
return self.config.num_envs
@property
def device(self) -> torch.device:
return torch.device(self.config.device)
# ── Initialization ───────────────────────────────────────────────
def _sim_initialize(self, config: MJXRunnerConfig) -> None:
# Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
self._mj_model = load_mujoco_model(self.env.robot)
self._mj_model.opt.timestep = config.dt
self._nq = self._mj_model.nq
self._nv = self._mj_model.nv
self._nu = self._mj_model.nu
# Step 2: Put model on GPU
self._mjx_model = mjx.put_model(self._mj_model)
# Step 3: Default reset state on GPU
default_data = mujoco.MjData(self._mj_model)
mujoco.mj_forward(self._mj_model, default_data)
self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
# Env-defined initial-state distribution (shared with the CPU
# runner) — baked into the JIT reset as constants.
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
self._nq, self._nv,
)
self._init_qpos_lo = jnp.asarray(qpos_lo)
self._init_qpos_hi = jnp.asarray(qpos_hi)
self._init_qvel_lo = jnp.asarray(qvel_lo)
self._init_qvel_hi = jnp.asarray(qvel_hi)
# Step 4: Initialise all environments with randomized states
self._rng = jax.random.PRNGKey(42)
self._batch_data = self._make_batched_data(config.num_envs)
# Step 4b: Build motor model info (ctrl_idx, qvel_idx, ActuatorConfig)
self._motor_info: list[tuple[int, int]] = []
self._motor_acts: list = []
for ctrl_idx, act in enumerate(self.env.robot.actuators):
if act.has_motor_model:
jnt_id = mujoco.mj_name2id(
self._mj_model, mujoco.mjtObj.mjOBJ_JOINT, act.joint,
)
qvel_idx = self._mj_model.jnt_dofadr[jnt_id]
self._motor_info.append((ctrl_idx, qvel_idx))
self._motor_acts.append(act)
# Step 5: JIT-compile the hot-path functions
self._compile_jit_fns(config.substeps)
# Keep one CPU MjData for offscreen rendering
self._render_data = mujoco.MjData(self._mj_model)
# Per-env DR scale arrays (synced from torch on every reset).
# Initialised to 1.0 here because _setup_domain_rand runs after this.
self._mjx_fr = jnp.ones(config.num_envs)
self._mjx_dp = jnp.ones(config.num_envs)
self._mjx_tq = jnp.ones(config.num_envs)
log.info(
"mjx_runner_ready",
num_envs=config.num_envs,
substeps=config.substeps,
jax_devices=str(jax.devices()),
)
def _make_batched_data(self, n: int):
"""Create *n* environments with env-defined initial randomization."""
self._rng, k1, k2 = jax.random.split(self._rng, 3)
pq = jax.random.uniform(
k1, (n, self._nq),
minval=self._init_qpos_lo, maxval=self._init_qpos_hi,
)
pv = jax.random.uniform(
k2, (n, self._nv),
minval=self._init_qvel_lo, maxval=self._init_qvel_hi,
)
default = self._default_mjx_data
model = self._mjx_model
def _init_one(pq_i, pv_i):
d = default.replace(
qpos=default.qpos + pq_i,
qvel=default.qvel + pv_i,
)
return mjx.forward(model, d)
return jax.vmap(_init_one)(pq, pv)
def _compile_jit_fns(self, substeps: int) -> None:
"""Pre-compile the two hot-path functions so the first call is fast."""
model = self._mjx_model
default = self._default_mjx_data
lim = ActuatorLimits(self._mj_model)
act_jnt_ids = jnp.array(lim.jnt_ids)
act_limited = jnp.array(lim.limited)
act_lo = jnp.array(lim.lo)
act_hi = jnp.array(lim.hi)
act_gs = jnp.array(lim.gear_sign)
# ── Motor model params (JAX arrays for JIT) ─────────────────
# Must stay in lock-step with ActuatorConfig.transform_ctrl() /
# compute_motor_force() in src/core/robot.py — sysid fits against
# the CPU implementation.
_has_motor = len(self._motor_info) > 0
if _has_motor:
acts = self._motor_acts
_ctrl_ids = jnp.array([c for c, _ in self._motor_info])
_qvel_ids = jnp.array([q for _, q in self._motor_info])
_ctrl_lo = jnp.array([a.ctrl_range[0] for a in acts])
_ctrl_hi = jnp.array([a.ctrl_range[1] for a in acts])
_bias = jnp.array([a.action_bias for a in acts])
_dz_pos = jnp.array([a.deadzone[0] for a in acts])
_dz_neg = jnp.array([a.deadzone[1] for a in acts])
_gear_pos = jnp.array([a.gear[0] for a in acts])
_gear_neg = jnp.array([a.gear[1] for a in acts])
_gear_avg = jnp.array([a.gear_avg for a in acts])
_fl_pos = jnp.array([a.frictionloss[0] for a in acts])
_fl_neg = jnp.array([a.frictionloss[1] for a in acts])
_strb_boost = jnp.array([a.stribeck_friction_boost for a in acts])
_strb_vel = jnp.array([a.stribeck_vel for a in acts])
_damp_pos = jnp.array([a.damping[0] for a in acts])
_damp_neg = jnp.array([a.damping[1] for a in acts])
_visc_quad = jnp.array([a.viscous_quadratic for a in acts])
_back_emf = jnp.array([a.back_emf_gain for a in acts])
# ── Batched step (N substeps per call) ──────────────────────
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
# multipliers (1.0 = off). Passed as args (not closure constants) so
# resampling them every episode does NOT trigger JIT recompilation.
@jax.jit
def step_fn(data, fr_scale, dp_scale, tq_scale):
fr = fr_scale[:, None] # broadcast over motor actuators
dp = dp_scale[:, None]
tq = tq_scale[:, None]
# Software limit switch: clamp ctrl once before substeps.
pos = data.qpos[:, act_jnt_ids]
ctrl = data.ctrl
at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0)
at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0)
ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl)
if _has_motor:
# Clip → bias → deadzone → asymmetric gear compensation
# (same order as ActuatorConfig.transform_ctrl).
mc = ctrl[:, _ctrl_ids]
mc = jnp.clip(mc, _ctrl_lo, _ctrl_hi)
mc = mc + _bias
mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc)
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
mc = mc * gear_dir / _gear_avg
mc = mc * tq # torque_scale (DR)
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
data = data.replace(ctrl=ctrl)
def body(_, d):
if _has_motor:
vel = d.qvel[:, _qvel_ids]
mc = d.ctrl[:, _ctrl_ids]
# Coulomb + Stribeck friction (direction-dependent) × DR
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
fl = fl + _strb_boost * jnp.exp(
-((jnp.abs(vel) / _strb_vel) ** 2)
)
fl = fl * fr
torque = -jnp.where(
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
)
# Viscous damping (direction-dependent) × DR scale
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
torque = torque - damp * vel
# Quadratic velocity drag
torque = torque - _visc_quad * vel * jnp.abs(vel)
# Back-EMF torque reduction
bemf = _back_emf * vel * jnp.sign(mc)
torque = torque - jnp.where(
jnp.abs(mc) > 1e-6, bemf, 0.0,
)
torque = jnp.clip(torque, -10.0, 10.0)
d = d.replace(
qfrc_applied=d.qfrc_applied.at[:, _qvel_ids].set(torque),
)
return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
return jax.lax.fori_loop(0, substeps, body, data)
self._jit_step = step_fn
# ── Selective reset ─────────────────────────────────────────
init_qpos_lo = self._init_qpos_lo
init_qpos_hi = self._init_qpos_hi
init_qvel_lo = self._init_qvel_lo
init_qvel_hi = self._init_qvel_hi
@jax.jit
def reset_fn(data, mask, rng):
rng, k1, k2 = jax.random.split(rng, 3)
ne = data.qpos.shape[0]
pq = jax.random.uniform(
k1, (ne, default.qpos.shape[0]),
minval=init_qpos_lo, maxval=init_qpos_hi,
)
pv = jax.random.uniform(
k2, (ne, default.qvel.shape[0]),
minval=init_qvel_lo, maxval=init_qvel_hi,
)
m = mask[:, None] # (num_envs, 1) broadcast helper
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
new_ctrl = jnp.where(m, 0.0, data.ctrl)
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
return new_data, rng
self._jit_reset = reset_fn
# ── Step / Reset ─────────────────────────────────────────────────
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# PyTorch → JAX (zero-copy on GPU via DLPack)
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
# Set ctrl & run N substeps for all environments (with per-env DR scales)
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
self._batch_data = self._jit_step(
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
)
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
return qpos, qvel
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Build boolean mask (fixed shape → no JIT recompilation)
mask = torch.zeros(
self.config.num_envs, dtype=torch.bool, device=self.device,
)
mask[env_ids] = True
mask_jax = jnp.from_dlpack(mask)
self._batch_data, self._rng = self._jit_reset(
self._batch_data, mask_jax, self._rng,
)
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
# resamples self._dr_scales just before this call, so re-deriving the
# full arrays here keeps the JAX copies current for every env.
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
# Return the FULL batch (BaseRunner indexes the reset envs in torch)
# — avoids a GPU→CPU sync + JAX gather on every step with a done env.
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
return qpos, qvel
# ── Rendering ────────────────────────────────────────────────────
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render — copies one env's state from GPU to CPU."""
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
self._render_data.ctrl[:] = np.asarray(self._batch_data.ctrl[env_idx])
mujoco.mj_forward(self._mj_model, self._render_data)
if not hasattr(self, "_offscreen_renderer"):
self._offscreen_renderer = mujoco.Renderer(
self._mj_model, width=640, height=480,
)
self._offscreen_renderer.update_scene(self._render_data)
return self._offscreen_renderer.render().copy()

View File

@@ -1,19 +1,206 @@
import dataclasses
import os
import tempfile
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
from pathlib import Path
import mujoco
import mujoco.viewer
import numpy as np
import torch
from src.core.env import BaseEnv
from src.core.robot import RobotConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
num_envs: int = 16
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
dt: float = 0.002
substeps: int = 10
class ActuatorLimits:
"""Software limit-switch: cuts motor ctrl when a joint hits its range.
The real robot has physical limit switches that kill motor current
at the travel endpoints. MuJoCo's built-in joint limits only apply
a spring force — they don't zero the actuator signal. This class
replicates the hardware behavior.
"""
def __init__(self, model: mujoco.MjModel) -> None:
jnt_ids = model.actuator_trnid[:model.nu, 0]
self.jnt_ids = jnt_ids
self.limited = model.jnt_limited[jnt_ids].astype(bool)
self.lo = model.jnt_range[jnt_ids, 0]
self.hi = model.jnt_range[jnt_ids, 1]
self.gear_sign = np.sign(model.actuator_gear[:model.nu, 0])
def enforce(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
"""Zero ctrl that would push past joint limits (call every substep)."""
if not self.limited.any():
return
pos = data.qpos[self.jnt_ids]
signed_ctrl = self.gear_sign * data.ctrl[:model.nu]
at_hi = self.limited & (pos >= self.hi) & (signed_ctrl > 0)
at_lo = self.limited & (pos <= self.lo) & (signed_ctrl < 0)
data.ctrl[at_hi | at_lo] = 0.0
# ── Public utilities ─────────────────────────────────────────────────
def load_mujoco_model(robot: RobotConfig) -> mujoco.MjModel:
"""Load a URDF (or MJCF) and apply robot.yaml settings.
Single model-loading entry point for all MuJoCo-based code:
training runners, MJX, and system identification.
Two-step approach required because MuJoCo's URDF parser ignores
``<actuator>`` in the ``<mujoco>`` extension block:
1. Load the URDF -> MuJoCo converts it to internal MJCF
2. Export the MJCF XML, inject actuators + joint overrides, reload
This keeps the URDF clean (re-exportable from CAD) -- all hardware
tuning lives in ``robot.yaml``.
"""
abs_path = robot.urdf_path.resolve()
model_dir = abs_path.parent
is_urdf = abs_path.suffix.lower() == ".urdf"
# -- Step 1: Load URDF with meshdir injection --
if is_urdf:
tree = ET.parse(abs_path)
root = tree.getroot()
# MuJoCo's URDF parser strips directory prefixes from mesh
# filenames, so we inject a <mujoco><compiler meshdir="..."/>
# block. The original URDF stays clean and simulator-agnostic.
meshdir = None
for mesh_el in root.iter("mesh"):
fn = mesh_el.get("filename", "")
parent = str(Path(fn).parent)
if parent and parent != ".":
meshdir = parent
break
if meshdir:
mj_ext = ET.SubElement(root, "mujoco")
ET.SubElement(mj_ext, "compiler", attrib={
"meshdir": meshdir,
"balanceinertia": "true",
})
# Write to a temp file (unique name for multiprocessing safety).
fd, tmp_path = tempfile.mkstemp(
suffix=".urdf", prefix="_mj_", dir=str(model_dir),
)
os.close(fd)
try:
tree.write(tmp_path, xml_declaration=True, encoding="unicode")
model_raw = mujoco.MjModel.from_xml_path(tmp_path)
finally:
Path(tmp_path).unlink(missing_ok=True)
else:
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
# If robot.yaml has no actuators/joints, we're done.
if not robot.actuators and not robot.joints:
return model_raw
# -- Step 2: Export MJCF, inject actuators + joint overrides --
fd, tmp_path = tempfile.mkstemp(
suffix=".xml", prefix="_mj_", dir=str(model_dir),
)
os.close(fd)
try:
mujoco.mj_saveLastXML(tmp_path, model_raw)
mjcf_str = Path(tmp_path).read_text()
root = ET.fromstring(mjcf_str)
# -- Inject actuators --
if robot.actuators:
act_elem = ET.SubElement(root, "actuator")
for act in robot.actuators:
attribs = {
"name": f"{act.joint}_{act.type}",
"joint": act.joint,
"gear": str(act.gear_avg),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
}
# dyntype is only available on <general>, not on
# shortcut elements like <motor>/<position>/<velocity>.
use_general = act.filter_tau > 0
if use_general:
attribs["dyntype"] = "filter"
attribs["dynprm"] = str(act.filter_tau)
attribs["gaintype"] = "fixed"
if act.type == "position":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
elif act.type == "velocity":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 0 -{act.kp}"
else: # motor
attribs["biastype"] = "none"
ET.SubElement(act_elem, "general", attrib=attribs)
else:
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, act.type, attrib=attribs)
# -- Apply joint overrides from robot.yaml --
# For actuated joints with a motor model, MuJoCo damping/frictionloss
# are set to 0 — the motor model handles them via qfrc_applied.
joint_damping: dict[str, float] = {}
joint_frictionloss: dict[str, float] = {}
for act in robot.actuators:
if act.has_motor_model:
joint_damping[act.joint] = 0.0
joint_frictionloss[act.joint] = 0.0
joint_armature: dict[str, float] = {}
for name, jcfg in robot.joints.items():
if jcfg.damping is not None:
joint_damping[name] = jcfg.damping
if jcfg.armature is not None:
joint_armature[name] = jcfg.armature
if jcfg.frictionloss is not None:
joint_frictionloss[name] = jcfg.frictionloss
for body in root.iter("body"):
for jnt in body.findall("joint"):
name = jnt.get("name")
if name in joint_damping:
jnt.set("damping", str(joint_damping[name]))
if name in joint_armature:
jnt.set("armature", str(joint_armature[name]))
if name in joint_frictionloss:
jnt.set("frictionloss", str(joint_frictionloss[name]))
# Disable self-collision on all geoms.
for geom in root.iter("geom"):
geom.set("contype", "0")
geom.set("conaffinity", "0")
modified_xml = ET.tostring(root, encoding="unicode")
Path(tmp_path).write_text(modified_xml)
return mujoco.MjModel.from_xml_path(tmp_path)
finally:
Path(tmp_path).unlink(missing_ok=True)
# -- Runner -----------------------------------------------------------
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
@@ -22,76 +209,62 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
@property
def num_envs(self) -> int:
return self.config.num_envs
@property
def device(self) -> torch.device:
return torch.device(self.config.device)
@staticmethod
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
"""Load a URDF (or MJCF) file and programmatically inject actuators.
Two-step approach required because MuJoCo's URDF parser ignores
<actuator> in the <mujoco> extension block:
1. Load the URDF → MuJoCo converts it to internal MJCF
2. Export the MJCF XML, add <actuator> elements, reload
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
# Step 1: Load URDF/MJCF as-is (no actuators)
model_raw = mujoco.MjModel.from_xml_path(model_path)
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation
tmp_mjcf = tempfile.mktemp(suffix=".xml")
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
finally:
import os
os.unlink(tmp_mjcf)
# Step 3: Inject actuators into the MJCF XML
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Step 4: Reload from modified MJCF
modified_xml = ET.tostring(root, encoding="unicode")
return mujoco.MjModel.from_xml_string(modified_xml)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
if model_path is None:
raise ValueError("model_path must be specified in the environment config")
actuators = self.env.config.actuators
self._model = self._load_model_with_actuators(str(model_path), actuators)
self._model = load_mujoco_model(self.env.robot)
self._model.opt.timestep = config.dt
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
self._data: list[mujoco.MjData] = [
mujoco.MjData(self._model) for _ in range(config.num_envs)
]
self._nq = self._model.nq
self._nv = self._model.nv
self._limits = ActuatorLimits(self._model)
# Build motor model: list of (actuator_config, joint_qvel_index) for
# actuators that have asymmetric motor dynamics.
self._motor_actuators: list[tuple] = []
for act in self.env.robot.actuators:
if act.has_motor_model:
jnt_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_JOINT, act.joint)
qvel_idx = self._model.jnt_dofadr[jnt_id]
self._motor_actuators.append((act, qvel_idx))
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
# Apply per-actuator ctrl transform (deadzone + gear compensation)
for act_idx, (act, _) in enumerate(self._motor_actuators):
for env_idx in range(self.num_envs):
actions_np[env_idx, act_idx] = act.transform_ctrl(
float(actions_np[env_idx, act_idx])
)
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
# Per-env domain-randomization scales (all 1.0 when DR is disabled).
fr_scale = self._dr_scales["friction_scale"].cpu().numpy()
dp_scale = self._dr_scales["damping_scale"].cpu().numpy()
tq_scale = self._dr_scales["torque_scale"].cpu().numpy()
for i, data in enumerate(self._data):
data.ctrl[:] = actions_np[i]
# torque_scale emulates motor-constant / battery-voltage variation.
data.ctrl[:] = actions_np[i] * tq_scale[i]
for _ in range(self.config.substeps):
# Apply asymmetric motor forces via qfrc_applied
for act, qvel_idx in self._motor_actuators:
vel = data.qvel[qvel_idx]
ctrl = data.ctrl[0] # TODO: generalise for multi-actuator
data.qfrc_applied[qvel_idx] = act.compute_motor_force(
vel, ctrl,
friction_scale=fr_scale[i],
damping_scale=dp_scale[i],
)
self._limits.enforce(self._model, data)
mujoco.mj_step(self._model, data)
qpos_batch[i] = data.qpos
@@ -101,55 +274,37 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
torch.from_numpy(qpos_batch).to(self.device),
torch.from_numpy(qvel_batch).to(self.device),
)
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
ids = env_ids.cpu().numpy()
n = len(ids)
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
# Env-defined initial-state distribution (shared with the MJX runner).
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
self._nq, self._nv,
)
for i, env_id in enumerate(ids):
for env_id in ids:
data = self._data[env_id]
mujoco.mj_resetData(self._model, data)
# Add small random perturbation so the pole doesn't start perfectly upright
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
data.qpos[:] += np.random.uniform(qpos_lo, qpos_hi)
data.qvel[:] += np.random.uniform(qvel_lo, qvel_hi)
data.ctrl[:] = 0.0
# Full-batch return (see BaseRunner._sim_reset contract).
qpos_batch = np.stack([d.qpos for d in self._data]).astype(np.float32)
qvel_batch = np.stack([d.qvel for d in self._data]).astype(np.float32)
return (
torch.from_numpy(qpos_batch).to(self.device),
torch.from_numpy(qvel_batch).to(self.device),
)
def _sim_close(self) -> None:
if hasattr(self, "_viewer") and self._viewer is not None:
self._viewer.close()
self._viewer = None
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
self._offscreen_renderer = None
self._data.clear()
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
if mode == "human":
if not hasattr(self, "_viewer") or self._viewer is None:
self._viewer = mujoco.viewer.launch_passive(
self._model, self._data[env_idx]
)
# Update visual geometry from current physics state
mujoco.mj_forward(self._model, self._data[env_idx])
self._viewer.sync()
return None
elif mode == "rgb_array":
# Cache the offscreen renderer to avoid create/destroy overhead
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
self._offscreen_renderer.update_scene(self._data[env_idx])
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
return torch.from_numpy(pixels)
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render of a single environment."""
if not hasattr(self, "_offscreen_renderer"):
self._offscreen_renderer = mujoco.Renderer(
self._model, width=640, height=480,
)
mujoco.mj_forward(self._model, self._data[env_idx])
self._offscreen_renderer.update_scene(self._data[env_idx])
return self._offscreen_renderer.render().copy()

494
src/runners/serial.py Normal file
View File

@@ -0,0 +1,494 @@
"""Serial runner — real hardware over USB/serial (ESP32).
Implements the BaseRunner interface for a single physical robot.
All physics come from the real world; the runner translates between
the ESP32 serial protocol and the qpos/qvel tensors that BaseRunner
and BaseEnv expect.
Serial protocol (ESP32 firmware):
Commands sent TO the ESP32:
G — start streaming state lines
H — stop streaming
M<int> — set motor PWM speed (-255 … 255)
State lines received FROM the ESP32 (firmware sends SI units):
S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
(7 comma-separated fields after the ``S`` prefix)
motor_rad — motor joint angle (radians)
motor_vel — motor joint velocity (rad/s)
pend_rad — pendulum angle (radians, 0 = hanging down)
pend_vel — pendulum angular velocity (rad/s)
motor_speed — applied PWM (-255..255, for action recording)
A daemon thread continuously reads the serial stream so the control
loop never blocks on I/O.
Usage:
python train.py env=rotary_cartpole runner=serial training=ppo_real
"""
from __future__ import annotations
import dataclasses
import logging
import threading
import time
from typing import Any
import numpy as np
import torch
from src.core.runner import BaseRunner, BaseRunnerConfig
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class SerialRunnerConfig(BaseRunnerConfig):
"""Configuration for serial communication with the ESP32."""
num_envs: int = 1 # always 1 — single physical robot
device: str = "cpu"
port: str = "/dev/cu.usbserial-0001"
baud: int = 115200
dt: float = 0.04 # control loop period (seconds), 25 Hz
no_data_timeout: float = 2.0 # seconds of silence → disconnect
# Physical reset procedure
reset_drive_speed: int = 80 # PWM for bang-bang drive-to-center
reset_deadband_rad: float = 0.01 # "centered" threshold (~0.6°)
reset_drive_timeout: float = 3.0 # seconds to reach center
reset_settle_timeout: float = 30.0 # seconds to wait for pendulum
class SerialRunner(BaseRunner[SerialRunnerConfig]):
"""BaseRunner implementation that talks to real hardware over serial.
Maps the ESP32 serial protocol to qpos/qvel tensors so the existing
RotaryCartPoleEnv (or any compatible env) works unchanged.
qpos layout: [motor_angle_rad, pendulum_angle_rad]
qvel layout: [motor_vel_rad_s, pendulum_vel_rad_s]
"""
# ------------------------------------------------------------------
# BaseRunner interface
# ------------------------------------------------------------------
@property
def num_envs(self) -> int:
return 1
@property
def device(self) -> torch.device:
return torch.device("cpu")
def _sim_initialize(self, config: SerialRunnerConfig) -> None:
# Joint dimensions for the rotary cartpole (motor + pendulum).
self._nq = 2
self._nv = 2
# Import serial here so it's not a hard dependency for sim-only users.
import serial as _serial
self._serial_mod = _serial
# Explicitly disable hardware flow control and exclusive mode to
# avoid termios.error (errno 22) on macOS with CH340/CP2102 adapters.
self.ser: _serial.Serial = _serial.Serial(
port=config.port,
baudrate=config.baud,
timeout=0.05,
xonxoff=False,
rtscts=False,
dsrdtr=False,
exclusive=False,
)
time.sleep(2) # Wait for ESP32 boot.
self.ser.reset_input_buffer()
# Internal state tracking.
self._rebooted: bool = False
self._serial_disconnected: bool = False
self._last_esp_ms: int = 0
self._last_data_time: float = time.monotonic()
self._streaming: bool = False
# Latest parsed state (updated by the reader thread).
# Firmware sends SI units — values are used directly as qpos/qvel.
self._latest_state: dict[str, Any] = {
"timestamp_ms": 0,
"motor_rad": 0.0,
"motor_vel": 0.0,
"pend_rad": 0.0,
"pend_vel": 0.0,
"motor_speed": 0,
}
self._state_lock = threading.Lock()
self._state_event = threading.Event()
# Start background serial reader.
self._reader_running = True
self._reader_thread = threading.Thread(
target=self._serial_reader, daemon=True
)
self._reader_thread.start()
# Start streaming.
self._send("G")
self._streaming = True
self._last_data_time = time.monotonic()
# Derive max PWM from actuator ctrl_range so the serial
# command range matches what MuJoCo training sees.
ctrl_hi = self.env.robot.actuators[0].ctrl_range[1]
self._max_pwm: int = round(ctrl_hi * 255)
# Track wall-clock time of last step for PPO-gap detection.
self._last_step_time: float = time.monotonic()
def _sim_step(
self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
now = time.monotonic()
# Detect PPO update gap: if more than 0.5s since last step,
# the optimizer was running and no motor commands were sent.
# Trigger a full reset so the robot starts from a clean state.
gap = now - self._last_step_time
if gap > 0.5:
logger.info(
"PPO update gap detected (%.1f s) — resetting before resuming.",
gap,
)
self._send("M0")
all_ids = torch.arange(self.num_envs, device=self.device)
self._sim_reset(all_ids)
self.step_counts.zero_()
# Map normalised action [-1, 1] → PWM, scaled by ctrl_range.
action_val = float(actions[0, 0].clamp(-1.0, 1.0))
motor_speed = int(action_val * self._max_pwm)
self._send(f"M{motor_speed}")
# Stream-driven: block until the firmware sends the next state
# line (~20 ms at 50 Hz).
state = self._read_state_blocking(timeout=0.1)
# Firmware sends SI units — use directly.
qpos, qvel = self._state_to_tensors(state)
# Cache for _sync_viz().
self._last_sync_state = state
self._last_step_time = time.monotonic()
return qpos, qvel
def _sim_reset(
self, env_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# If ESP32 rebooted or disconnected, we can't recover.
if self._rebooted or self._serial_disconnected:
raise RuntimeError(
"ESP32 rebooted or disconnected during training! "
"Encoder center is lost. "
"Please re-center the motor manually and restart."
)
# Stop motor and restart streaming.
self._send("M0")
self._send("H")
self._streaming = False
time.sleep(0.05)
self._state_event.clear()
self._send("G")
self._streaming = True
self._last_data_time = time.monotonic()
time.sleep(0.05)
# Physically return the motor to the centre position.
self._drive_to_center()
# Wait until the env considers the robot settled.
self._wait_for_settle()
# Refresh data timer so health checks don't false-positive.
self._last_data_time = time.monotonic()
# Read settled state and return as qpos/qvel.
state = self._read_state_blocking()
qpos, qvel = self._state_to_tensors(state)
self._last_sync_state = state
return qpos, qvel
def _sim_close(self) -> None:
self._reader_running = False
self._streaming = False
self._send("H")
self._send("M0")
time.sleep(0.1)
self._reader_thread.join(timeout=1.0)
self.ser.close()
super()._sim_close()
# ------------------------------------------------------------------
# MuJoCo digital-twin rendering
# ------------------------------------------------------------------
def _ensure_viz_model(self) -> None:
"""Lazily load the MuJoCo model for visualisation (digital twin)."""
if hasattr(self, "_viz_model"):
return
import mujoco
from src.runners.mujoco import load_mujoco_model
self._viz_model = load_mujoco_model(self.env.robot)
self._viz_data = mujoco.MjData(self._viz_model)
self._offscreen_renderer = None
def _sync_viz(self) -> None:
"""Copy current serial sensor state into the MuJoCo viz model."""
import mujoco
self._ensure_viz_model()
last_state = getattr(self, "_last_sync_state", None)
if last_state is None:
last_state = self._read_state()
# Firmware sends radians — use directly.
self._viz_data.qpos[0] = last_state["motor_rad"]
self._viz_data.qpos[1] = last_state["pend_rad"]
self._viz_data.qvel[0] = last_state["motor_vel"]
self._viz_data.qvel[1] = last_state["pend_vel"]
mujoco.mj_forward(self._viz_model, self._viz_data)
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render of the digital-twin MuJoCo model."""
import mujoco
self._sync_viz()
if self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(
self._viz_model, width=640, height=480,
)
self._offscreen_renderer.update_scene(self._viz_data)
return self._offscreen_renderer.render().copy()
# ------------------------------------------------------------------
# Override step() for runner-level safety
# ------------------------------------------------------------------
def step(
self, actions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
# Check for ESP32 reboot / disconnect BEFORE stepping.
if self._rebooted or self._serial_disconnected:
self._send("M0")
qpos, qvel = self._make_current_state()
state = self.env.build_state(qpos, qvel)
obs = self.env.compute_observations(state)
reward = torch.tensor([[-100.0]])
terminated = torch.tensor([[True]])
truncated = torch.tensor([[False]])
return obs, reward, terminated, truncated, {"reboot_detected": True}
# Normal step via BaseRunner (calls _sim_step → env logic).
obs, rewards, terminated, truncated, info = super().step(actions)
# Check connection health after stepping.
if not self._check_connection_health():
self._send("M0")
terminated = torch.tensor([[True]])
rewards = torch.tensor([[-100.0]])
info["reboot_detected"] = True
# Always stop motor on episode end.
if terminated.any() or truncated.any():
self._send("M0")
return obs, rewards, terminated, truncated, info
# ------------------------------------------------------------------
# Serial helpers
# ------------------------------------------------------------------
def _send(self, cmd: str) -> None:
"""Send a command to the ESP32."""
try:
self.ser.write(f"{cmd}\n".encode())
except (OSError, self._serial_mod.SerialException):
self._serial_disconnected = True
def _serial_reader(self) -> None:
"""Background thread: continuously read and parse serial lines.
Protocol: ``S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>``
(7 comma-separated fields). Firmware sends SI units directly.
"""
while self._reader_running:
try:
if self.ser.in_waiting:
line = (
self.ser.readline()
.decode("utf-8", errors="ignore")
.strip()
)
# Detect ESP32 reboot: it prints READY on startup.
if line.startswith("READY"):
self._rebooted = True
logger.critical("ESP32 reboot detected: %s", line)
continue
if line.startswith("S,"):
parts = line.split(",")
if len(parts) >= 7:
try:
esp_ms = int(parts[1])
except ValueError:
logger.debug(
"Malformed state line (header): %s",
line,
)
continue
# Detect reboot: timestamp jumped backwards.
if (
self._last_esp_ms > 5000
and esp_ms < self._last_esp_ms - 3000
):
self._rebooted = True
logger.critical(
"ESP32 reboot detected: timestamp"
" %d -> %d",
self._last_esp_ms,
esp_ms,
)
self._last_esp_ms = esp_ms
self._last_data_time = time.monotonic()
try:
parsed: dict[str, Any] = {
"timestamp_ms": esp_ms,
"motor_rad": float(parts[2]),
"motor_vel": float(parts[3]),
"pend_rad": float(parts[4]),
"pend_vel": float(parts[5]),
"motor_speed": int(parts[6]),
}
except ValueError:
logger.debug(
"Malformed state line (fields): %s",
line,
)
continue
with self._state_lock:
self._latest_state = parsed
self._state_event.set()
else:
time.sleep(0.001) # Avoid busy-spinning.
except (OSError, self._serial_mod.SerialException) as exc:
self._serial_disconnected = True
logger.critical("Serial connection lost: %s", exc)
break
def _check_connection_health(self) -> bool:
"""Return True if the ESP32 connection appears healthy."""
if self._serial_disconnected:
logger.critical("ESP32 serial connection lost.")
return False
if (
self._streaming
and (time.monotonic() - self._last_data_time)
> self.config.no_data_timeout
):
logger.critical(
"No data from ESP32 for %.1f s — possible crash/disconnect.",
time.monotonic() - self._last_data_time,
)
self._rebooted = True
return False
return True
def _read_state(self) -> dict[str, Any]:
"""Return the most recent state from the reader thread (non-blocking)."""
with self._state_lock:
return dict(self._latest_state)
def _read_state_blocking(self, timeout: float = 0.05) -> dict[str, Any]:
"""Wait for a fresh sample, then return it."""
self._state_event.clear()
self._state_event.wait(timeout=timeout)
with self._state_lock:
return dict(self._latest_state)
def _state_to_tensors(
self, state: dict[str, Any],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert a parsed state dict to (qpos, qvel) tensors."""
qpos = torch.tensor(
[[state["motor_rad"], state["pend_rad"]]], dtype=torch.float32
)
qvel = torch.tensor(
[[state["motor_vel"], state["pend_vel"]]], dtype=torch.float32
)
return qpos, qvel
def _make_current_state(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Build qpos/qvel from current sensor data (utility)."""
return self._state_to_tensors(self._read_state_blocking())
# ------------------------------------------------------------------
# Physical reset helpers
# ------------------------------------------------------------------
def _drive_to_center(self) -> None:
"""Drive the motor back toward center using bang-bang control."""
cfg = self.config
start = time.time()
while time.time() - start < cfg.reset_drive_timeout:
state = self._read_state_blocking()
motor_rad = state["motor_rad"]
if abs(motor_rad) < cfg.reset_deadband_rad:
break
speed = cfg.reset_drive_speed if motor_rad < 0 else -cfg.reset_drive_speed
self._send(f"M{speed}")
time.sleep(0.05)
self._send("M0")
time.sleep(0.2)
def _wait_for_settle(self) -> None:
"""Block until the env considers the robot ready for a new episode."""
cfg = self.config
stable_since: float | None = None
start = time.monotonic()
while time.monotonic() - start < cfg.reset_settle_timeout:
state = self._read_state_blocking()
qpos, qvel = self._state_to_tensors(state)
if self.env.is_reset_ready(qpos, qvel):
if stable_since is None:
stable_since = time.monotonic()
elif time.monotonic() - stable_since >= 0.5:
logger.info(
"Robot settled after %.2f s",
time.monotonic() - start,
)
return
else:
stable_since = None
time.sleep(0.02)
logger.warning(
"Robot did not settle within %.1f s — proceeding anyway.",
cfg.reset_settle_timeout,
)

1
src/sysid/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""System identification — tune simulation parameters to match real hardware."""

79
src/sysid/_urdf.py Normal file
View File

@@ -0,0 +1,79 @@
"""URDF XML helpers shared by sysid rollout and export modules."""
from __future__ import annotations
import xml.etree.ElementTree as ET
def set_mass(inertial: ET.Element, mass: float | None) -> None:
if mass is None:
return
mass_el = inertial.find("mass")
if mass_el is not None:
mass_el.set("value", str(mass))
def set_com(
inertial: ET.Element,
x: float | None,
y: float | None,
z: float | None,
) -> None:
origin = inertial.find("origin")
if origin is None:
return
xyz = origin.get("xyz", "0 0 0").split()
if x is not None:
xyz[0] = str(x)
if y is not None:
xyz[1] = str(y)
if z is not None:
xyz[2] = str(z)
origin.set("xyz", " ".join(xyz))
def set_inertia(
inertial: ET.Element,
ixx: float | None = None,
iyy: float | None = None,
izz: float | None = None,
ixy: float | None = None,
iyz: float | None = None,
ixz: float | None = None,
) -> None:
ine = inertial.find("inertia")
if ine is None:
return
for attr, val in [
("ixx", ixx), ("iyy", iyy), ("izz", izz),
("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
]:
if val is not None:
ine.set(attr, str(val))
def patch_link_inertials(
root: ET.Element,
params: dict[str, float],
) -> None:
"""Patch arm and pendulum inertial parameters in a URDF ElementTree root."""
for link in root.iter("link"):
link_name = link.get("name", "")
inertial = link.find("inertial")
if inertial is None:
continue
if link_name == "arm":
set_mass(inertial, params.get("arm_mass"))
set_com(inertial, params.get("arm_com_x"),
params.get("arm_com_y"), params.get("arm_com_z"))
elif link_name == "pendulum":
set_mass(inertial, params.get("pendulum_mass"))
set_com(inertial, params.get("pendulum_com_x"),
params.get("pendulum_com_y"), params.get("pendulum_com_z"))
set_inertia(inertial,
ixx=params.get("pendulum_ixx"),
iyy=params.get("pendulum_iyy"),
izz=params.get("pendulum_izz"),
ixy=params.get("pendulum_ixy"))

434
src/sysid/capture.py Normal file
View File

@@ -0,0 +1,434 @@
"""Capture a real-robot trajectory under random excitation (PRBS-style).
Connects to the ESP32 over serial, sends random PWM commands to excite
the system, and records motor + pendulum angles and velocities at ~50 Hz.
Saves a compressed numpy archive (.npz) that the optimizer can replay
in simulation to fit physics parameters.
Serial protocol (same as SerialRunner):
S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
(7 comma-separated fields — firmware sends SI units)
Usage:
python -m src.sysid.capture \
--robot-path assets/rotary_cartpole \
--port /dev/cu.usbserial-0001 \
--duration 20
"""
from __future__ import annotations
import argparse
import math
import random
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Any
import numpy as np
import structlog
log = structlog.get_logger()
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
def _parse_state_line(line: str) -> dict[str, Any] | None:
"""Parse an ``S,…`` state line from the ESP32.
Format: S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
(7 comma-separated fields, firmware sends SI units)
"""
if not line.startswith("S,"):
return None
parts = line.split(",")
if len(parts) < 7:
return None
try:
return {
"timestamp_ms": int(parts[1]),
"motor_rad": float(parts[2]),
"motor_vel": float(parts[3]),
"pend_rad": float(parts[4]),
"pend_vel": float(parts[5]),
"motor_speed": int(parts[6]),
}
except (ValueError, IndexError):
return None
# ── Background serial reader ─────────────────────────────────────────
class _SerialReader:
"""Minimal background reader for the ESP32 serial stream.
Uses a sequence counter so ``read_blocking()`` guarantees it returns
a *new* state line (not a stale repeat). This keeps the capture
loop locked to the firmware's 50 Hz tick.
"""
def __init__(self, port: str, baud: int = 115200):
import serial as _serial
self._serial_mod = _serial
self.ser = _serial.Serial(port, baud, timeout=0.05)
time.sleep(2) # Wait for ESP32 boot.
self.ser.reset_input_buffer()
self._latest: dict[str, Any] = {}
self._seq: int = 0 # incremented on every new state line
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
self._running = True
self._thread = threading.Thread(target=self._reader_loop, daemon=True)
self._thread.start()
def _reader_loop(self) -> None:
_debug_count = 0
while self._running:
try:
if self.ser.in_waiting:
line = (
self.ser.readline()
.decode("utf-8", errors="ignore")
.strip()
)
# Debug: log first 10 raw lines so we can see what the firmware sends.
if _debug_count < 10 and line:
log.info("serial_raw_line", line=repr(line), count=_debug_count)
_debug_count += 1
parsed = _parse_state_line(line)
if parsed is not None:
with self._cond:
self._latest = parsed
self._seq += 1
self._cond.notify_all()
else:
time.sleep(0.001)
except (OSError, self._serial_mod.SerialException):
log.critical("serial_lost")
break
def send(self, cmd: str) -> None:
try:
self.ser.write(f"{cmd}\n".encode())
except (OSError, self._serial_mod.SerialException):
log.critical("serial_send_failed", cmd=cmd)
def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
"""Wait until a *new* state line arrives, then return it.
Uses a sequence counter to guarantee the returned state is
different from whatever was available before this call.
"""
with self._cond:
seq_before = self._seq
if not self._cond.wait_for(
lambda: self._seq > seq_before, timeout=timeout
):
return {} # timeout — no new data
return dict(self._latest)
def close(self) -> None:
self._running = False
self.send("H")
self.send("M0")
time.sleep(0.1)
self._thread.join(timeout=1.0)
self.ser.close()
# ── PRBS excitation signal ───────────────────────────────────────────
class _PRBSExcitation:
"""Random hold-value excitation with configurable amplitude and hold time.
At each call to ``__call__``, returns the current PWM value.
The value is held for a random duration (``hold_min````hold_max`` ms),
then a new random value is drawn uniformly from ``[-amplitude, +amplitude]``.
"""
def __init__(
self,
amplitude: int = 150,
hold_min_ms: int = 50,
hold_max_ms: int = 300,
):
self.amplitude = amplitude
self.hold_min_ms = hold_min_ms
self.hold_max_ms = hold_max_ms
self._current: int = 0
self._switch_time: float = 0.0
self._new_value()
def _new_value(self) -> None:
self._current = random.randint(-self.amplitude, self.amplitude)
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
self._switch_time = time.monotonic() + hold_ms / 1000.0
def __call__(self) -> int:
if time.monotonic() >= self._switch_time:
self._new_value()
return self._current
# ── Main capture loop ────────────────────────────────────────────────
def capture(
robot_path: str | Path,
port: str = "/dev/cu.usbserial-0001",
baud: int = 115200,
duration: float = 20.0,
amplitude: int = 150,
hold_min_ms: int = 50,
hold_max_ms: int = 300,
dt: float = 0.02,
motor_angle_limit_deg: float = 90.0,
) -> Path:
"""Run the capture procedure and return the path to the saved .npz file.
The capture loop is **stream-driven**: it blocks on each incoming
state line from the firmware (which arrives at 50 Hz), sends the
next motor command immediately, and records both.
Parameters
----------
robot_path : path to robot asset directory
port : serial port for ESP32
baud : baud rate
duration : capture duration in seconds
amplitude : max PWM magnitude for excitation
hold_min_ms / hold_max_ms : random hold time range (ms)
dt : nominal sample period for buffer sizing (seconds)
motor_angle_limit_deg : safety limit for motor angle
"""
robot_path = Path(robot_path).resolve()
max_motor_rad = math.radians(motor_angle_limit_deg) if motor_angle_limit_deg > 0 else 0.0
# Connect.
reader = _SerialReader(port, baud)
excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)
# Prepare recording buffers (generous headroom).
max_samples = int(duration / dt) + 500
rec_time = np.zeros(max_samples, dtype=np.float64)
rec_action = np.zeros(max_samples, dtype=np.float64)
rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
rec_pend_vel = np.zeros(max_samples, dtype=np.float64)
# Start streaming.
reader.send("G")
time.sleep(0.1)
log.info(
"capture_starting",
port=port,
duration=duration,
amplitude=amplitude,
hold_range_ms=f"{hold_min_ms}{hold_max_ms}",
mode="stream-driven (firmware clock)",
)
idx = 0
pwm = 0
last_esp_ms = -1 # firmware timestamp of last recorded sample
esp_ms_origin: int | None = None # first firmware timestamp
no_data_count = 0 # consecutive timeouts with no data
t0 = time.monotonic() # host clock for duration check only
try:
while True:
# Block until the firmware sends the next state line (~20 ms).
# Timeout at 100 ms prevents hanging if the ESP32 disconnects.
state = reader.read_blocking(timeout=0.1)
if not state:
no_data_count += 1
if no_data_count == 30: # 3 seconds with no data
log.warning(
"no_data_received",
msg="No state lines from firmware after 3s. "
"Check: is the ESP32 powered? Is it running the right firmware? "
"Try pressing the RESET button.",
)
if no_data_count == 100: # 10 seconds
log.critical(
"no_data_timeout",
msg="No data for 10s — aborting capture.",
)
break
continue # no data yet — retry
no_data_count = 0
# Deduplicate: the firmware may send multiple state lines per
# tick (e.g. M-command echo + tick). Only record one sample
# per unique firmware timestamp.
esp_ms = state.get("timestamp_ms", 0)
if esp_ms == last_esp_ms:
continue
last_esp_ms = esp_ms
# Use firmware clock for time axis (avoids host serial jitter).
if esp_ms_origin is None:
esp_ms_origin = esp_ms
elapsed = (esp_ms - esp_ms_origin) / 1000.0
if elapsed >= duration:
break
# Get excitation PWM for the NEXT tick.
pwm = excitation()
# Safety: keep the arm well within its mechanical range.
# Firmware sends motor angle in radians — use directly.
motor_angle_rad = state.get("motor_rad", 0.0)
if max_motor_rad > 0:
ratio = motor_angle_rad / max_motor_rad # signed, -1..+1
abs_ratio = abs(ratio)
if abs_ratio > 0.90:
# Deep in the danger zone — force a strong return.
brake_strength = min(1.0, (abs_ratio - 0.90) / 0.10) # 0→1
brake_pwm = int(amplitude * (0.5 + 0.5 * brake_strength))
pwm = -brake_pwm if ratio > 0 else brake_pwm
elif abs_ratio > 0.70:
# Soft zone — only allow actions pointing back to centre.
if ratio > 0 and pwm > 0:
pwm = -abs(pwm)
elif ratio < 0 and pwm < 0:
pwm = abs(pwm)
# Send command immediately — it will take effect on the next tick.
reader.send(f"M{pwm}")
# Record this tick's state + the action the motor *actually*
# received. Firmware sends SI units — use directly.
motor_angle = state.get("motor_rad", 0.0)
motor_vel = state.get("motor_vel", 0.0)
pend_angle = state.get("pend_rad", 0.0)
pend_vel = state.get("pend_vel", 0.0)
# Firmware constrains to ±255; normalise to [-1, 1].
applied = state.get("motor_speed", 0)
action_norm = max(-255, min(255, applied)) / 255.0
if idx < max_samples:
rec_time[idx] = elapsed
rec_action[idx] = action_norm
rec_motor_angle[idx] = motor_angle
rec_motor_vel[idx] = motor_vel
rec_pend_angle[idx] = pend_angle
rec_pend_vel[idx] = pend_vel
idx += 1
else:
break # buffer full
# Progress (every 50 samples ≈ once per second at 50 Hz).
if idx % 50 == 0:
log.info(
"capture_progress",
elapsed=f"{elapsed:.1f}/{duration:.0f}s",
samples=idx,
pwm=pwm,
)
finally:
reader.send("M0")
reader.close()
# Trim to actual sample count.
rec_time = rec_time[:idx]
rec_action = rec_action[:idx]
rec_motor_angle = rec_motor_angle[:idx]
rec_motor_vel = rec_motor_vel[:idx]
rec_pend_angle = rec_pend_angle[:idx]
rec_pend_vel = rec_pend_vel[:idx]
# Save.
recordings_dir = robot_path / "recordings"
recordings_dir.mkdir(exist_ok=True)
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = recordings_dir / f"capture_{stamp}.npz"
np.savez_compressed(
out_path,
time=rec_time,
action=rec_action,
motor_angle=rec_motor_angle,
motor_vel=rec_motor_vel,
pendulum_angle=rec_pend_angle,
pendulum_vel=rec_pend_vel,
)
log.info(
"capture_saved",
path=str(out_path),
samples=idx,
duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
)
return out_path
# ── CLI entry point ──────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="Capture a real-robot trajectory for system identification."
)
parser.add_argument(
"--robot-path",
type=str,
default="assets/rotary_cartpole",
help="Path to robot asset directory",
)
parser.add_argument(
"--port",
type=str,
default="/dev/cu.usbserial-0001",
help="Serial port for ESP32",
)
parser.add_argument("--baud", type=int, default=115200)
parser.add_argument(
"--duration", type=float, default=20.0, help="Capture duration (s)"
)
parser.add_argument(
"--amplitude", type=int, default=150,
help="Max PWM magnitude for excitation (0-255)",
)
parser.add_argument(
"--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
)
parser.add_argument(
"--hold-max-ms", type=int, default=300, help="Max hold time (ms)"
)
parser.add_argument(
"--dt", type=float, default=0.02, help="Nominal sample period for buffer sizing (s)"
)
parser.add_argument(
"--motor-angle-limit", type=float, default=90.0,
help="Motor angle safety limit in degrees (0 = disabled)",
)
args = parser.parse_args()
capture(
robot_path=args.robot_path,
port=args.port,
baud=args.baud,
duration=args.duration,
amplitude=args.amplitude,
hold_min_ms=args.hold_min_ms,
hold_max_ms=args.hold_max_ms,
dt=args.dt,
motor_angle_limit_deg=args.motor_angle_limit,
)
if __name__ == "__main__":
main()

182
src/sysid/export.py Normal file
View File

@@ -0,0 +1,182 @@
"""Export tuned parameters to URDF and robot.yaml files.
Reads the original files, injects the optimised parameter values,
and writes ``rotary_cartpole_tuned.urdf`` + ``robot_tuned.yaml``
alongside the originals in the robot asset directory.
"""
from __future__ import annotations
import copy
import xml.etree.ElementTree as ET
from pathlib import Path
import structlog
import yaml
from src.sysid._urdf import patch_link_inertials
log = structlog.get_logger()
def export_tuned_files(
robot_path: str | Path,
params: dict[str, float],
motor_params: dict[str, float] | None = None,
) -> tuple[Path, Path]:
"""Write tuned URDF and robot.yaml files.
Parameters
----------
robot_path : robot asset directory (contains robot.yaml + *.urdf)
params : dict of parameter name → tuned value (the optimised set)
motor_params : locked motor parameters merged underneath ``params``
(``params`` wins on conflicts) so the exported YAML always has a
complete motor model
Returns
-------
(tuned_urdf_path, tuned_robot_yaml_path)
"""
robot_path = Path(robot_path).resolve()
if motor_params:
params = {**motor_params, **params}
# ── Load originals ───────────────────────────────────────────
robot_yaml_path = robot_path / "robot.yaml"
robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
urdf_path = robot_path / robot_cfg["urdf"]
# ── Tune URDF ────────────────────────────────────────────────
tree = ET.parse(urdf_path)
patch_link_inertials(tree.getroot(), params)
# Write tuned URDF.
tuned_urdf_name = urdf_path.stem + "_tuned" + urdf_path.suffix
tuned_urdf_path = robot_path / tuned_urdf_name
# Preserve the XML declaration and original formatting as much as possible.
ET.indent(tree, space=" ")
tree.write(str(tuned_urdf_path), xml_declaration=True, encoding="unicode")
log.info("tuned_urdf_written", path=str(tuned_urdf_path))
# ── Tune robot.yaml ──────────────────────────────────────────
tuned_cfg = copy.deepcopy(robot_cfg)
# Point to the tuned URDF.
tuned_cfg["urdf"] = tuned_urdf_name
# Update actuator parameters — full asymmetric motor model.
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
act = tuned_cfg["actuators"][0]
# Asymmetric gear, damping, deadzone, frictionloss as [pos, neg].
act["gear"] = [
round(params.get("actuator_gear_pos", 0.424), 6),
round(params.get("actuator_gear_neg", 0.425), 6),
]
act["damping"] = [
round(params.get("motor_damping_pos", 0.002), 6),
round(params.get("motor_damping_neg", 0.015), 6),
]
act["deadzone"] = [
round(params.get("motor_deadzone_pos", 0.141), 6),
round(params.get("motor_deadzone_neg", 0.078), 6),
]
act["frictionloss"] = [
round(params.get("motor_frictionloss_pos", 0.057), 6),
round(params.get("motor_frictionloss_neg", 0.053), 6),
]
if "actuator_filter_tau" in params:
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
# Stribeck friction and action bias.
if "stribeck_friction_boost" in params:
act["stribeck_friction_boost"] = round(params["stribeck_friction_boost"], 6)
if "stribeck_vel" in params:
act["stribeck_vel"] = round(params["stribeck_vel"], 6)
if "action_bias" in params:
act["action_bias"] = round(params["action_bias"], 6)
# ctrl_range from ctrl_limit parameter.
if "ctrl_limit" in params:
lim = round(params["ctrl_limit"], 6)
act["ctrl_range"] = [-lim, lim]
# Update joint overrides.
if "joints" not in tuned_cfg:
tuned_cfg["joints"] = {}
if "motor_joint" not in tuned_cfg["joints"]:
tuned_cfg["joints"]["motor_joint"] = {}
mj = tuned_cfg["joints"]["motor_joint"]
if "motor_armature" in params:
mj["armature"] = round(params["motor_armature"], 6)
# Frictionloss/damping = 0 in MuJoCo (motor model handles via qfrc_applied).
mj["frictionloss"] = 0.0
if "pendulum_joint" not in tuned_cfg["joints"]:
tuned_cfg["joints"]["pendulum_joint"] = {}
pj = tuned_cfg["joints"]["pendulum_joint"]
if "pendulum_damping" in params:
pj["damping"] = round(params["pendulum_damping"], 6)
if "pendulum_frictionloss" in params:
pj["frictionloss"] = round(params["pendulum_frictionloss"], 6)
# Write tuned robot.yaml.
tuned_yaml_path = robot_path / "robot_tuned.yaml"
# Add a header comment.
header = (
"# Tuned robot config — generated by src.sysid.optimize\n"
"# Original: robot.yaml\n"
"# Run `python -m src.sysid.visualize` to compare real vs sim.\n\n"
)
tuned_yaml_path.write_text(
header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
)
log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))
return tuned_urdf_path, tuned_yaml_path
# ── CLI entry point ──────────────────────────────────────────────────
def main() -> None:
import argparse
import json
parser = argparse.ArgumentParser(
description="Export tuned URDF + robot.yaml from sysid results."
)
parser.add_argument(
"--robot-path", type=str, default="assets/rotary_cartpole",
help="Path to robot asset directory",
)
parser.add_argument(
"--result", type=str, default=None,
help="Path to sysid_result.json (auto-detected if omitted)",
)
args = parser.parse_args()
robot_path = Path(args.robot_path).resolve()
if args.result:
result_path = Path(args.result)
else:
result_path = robot_path / "sysid_result.json"
if not result_path.exists():
raise FileNotFoundError(f"Result file not found: {result_path}")
result = json.loads(result_path.read_text())
export_tuned_files(
robot_path=args.robot_path,
params=result["best_params"],
)
if __name__ == "__main__":
main()

186
src/sysid/motor/export.py Normal file
View File

@@ -0,0 +1,186 @@
"""Export tuned motor parameters to MJCF and robot.yaml files.
Reads the original motor.xml and robot.yaml, patches with optimised
parameter values, and writes motor_tuned.xml + robot_tuned.yaml.
Usage:
python -m src.sysid.motor.export \
--asset-path assets/motor \
--result assets/motor/motor_sysid_result.json
"""
from __future__ import annotations
import argparse
import copy
import json
import xml.etree.ElementTree as ET
from pathlib import Path
import structlog
import yaml
log = structlog.get_logger()
_DEFAULT_ASSET = "assets/motor"
def export_tuned_files(
asset_path: str | Path,
params: dict[str, float],
) -> tuple[Path, Path]:
"""Write tuned MJCF and robot.yaml files.
Returns (tuned_mjcf_path, tuned_robot_yaml_path).
"""
asset_path = Path(asset_path).resolve()
robot_yaml_path = asset_path / "robot.yaml"
robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
mjcf_path = asset_path / robot_cfg["mjcf"]
# ── Tune MJCF ────────────────────────────────────────────────
tree = ET.parse(str(mjcf_path))
root = tree.getroot()
# Actuator — use average gear for the MJCF model.
gear_pos = params.get("actuator_gear_pos", params.get("actuator_gear"))
gear_neg = params.get("actuator_gear_neg", params.get("actuator_gear"))
gear_avg = None
if gear_pos is not None and gear_neg is not None:
gear_avg = (gear_pos + gear_neg) / 2.0
elif gear_pos is not None:
gear_avg = gear_pos
filter_tau = params.get("actuator_filter_tau")
for act_el in root.iter("general"):
if act_el.get("name") == "motor":
if gear_avg is not None:
act_el.set("gear", str(gear_avg))
if filter_tau is not None:
if filter_tau > 0:
act_el.set("dyntype", "filter")
act_el.set("dynprm", str(filter_tau))
else:
act_el.set("dyntype", "none")
# Joint — average damping & friction for MJCF (asymmetry in runtime).
fl_pos = params.get("motor_frictionloss_pos", params.get("motor_frictionloss"))
fl_neg = params.get("motor_frictionloss_neg", params.get("motor_frictionloss"))
fl_avg = None
if fl_pos is not None and fl_neg is not None:
fl_avg = (fl_pos + fl_neg) / 2.0
elif fl_pos is not None:
fl_avg = fl_pos
damp_pos = params.get("motor_damping_pos", params.get("motor_damping"))
damp_neg = params.get("motor_damping_neg", params.get("motor_damping"))
damp_avg = None
if damp_pos is not None and damp_neg is not None:
damp_avg = (damp_pos + damp_neg) / 2.0
elif damp_pos is not None:
damp_avg = damp_pos
for jnt in root.iter("joint"):
if jnt.get("name") == "motor_joint":
if damp_avg is not None:
jnt.set("damping", str(damp_avg))
if "motor_armature" in params:
jnt.set("armature", str(params["motor_armature"]))
if fl_avg is not None:
jnt.set("frictionloss", str(fl_avg))
# Rotor mass.
if "rotor_mass" in params:
for geom in root.iter("geom"):
if geom.get("name") == "rotor_disk":
geom.set("mass", str(params["rotor_mass"]))
# Write tuned MJCF.
tuned_mjcf_name = mjcf_path.stem + "_tuned" + mjcf_path.suffix
tuned_mjcf_path = asset_path / tuned_mjcf_name
ET.indent(tree, space=" ")
tree.write(str(tuned_mjcf_path), xml_declaration=True, encoding="unicode")
log.info("tuned_mjcf_written", path=str(tuned_mjcf_path))
# ── Tune robot.yaml ──────────────────────────────────────────
tuned_cfg = copy.deepcopy(robot_cfg)
tuned_cfg["mjcf"] = tuned_mjcf_name
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
act = tuned_cfg["actuators"][0]
if gear_avg is not None:
act["gear"] = round(gear_avg, 6)
if "actuator_filter_tau" in params:
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
if "motor_damping" in params:
act["damping"] = round(params["motor_damping"], 6)
if "joints" not in tuned_cfg:
tuned_cfg["joints"] = {}
if "motor_joint" not in tuned_cfg["joints"]:
tuned_cfg["joints"]["motor_joint"] = {}
mj = tuned_cfg["joints"]["motor_joint"]
if "motor_armature" in params:
mj["armature"] = round(params["motor_armature"], 6)
if fl_avg is not None:
mj["frictionloss"] = round(fl_avg, 6)
# Asymmetric / hardware-realism / nonlinear parameters.
realism = {}
for key in [
"actuator_gear_pos", "actuator_gear_neg",
"motor_damping_pos", "motor_damping_neg",
"motor_frictionloss_pos", "motor_frictionloss_neg",
"motor_deadzone_pos", "motor_deadzone_neg",
"action_bias",
"viscous_quadratic", "back_emf_gain",
"stribeck_friction_boost", "stribeck_vel",
"gearbox_backlash",
]:
if key in params:
realism[key] = round(params[key], 6)
if realism:
tuned_cfg["hardware_realism"] = realism
tuned_yaml_path = asset_path / "robot_tuned.yaml"
header = (
"# Tuned motor config — generated by src.sysid.motor.optimize\n"
"# Original: robot.yaml\n\n"
)
tuned_yaml_path.write_text(
header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
)
log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))
return tuned_mjcf_path, tuned_yaml_path
# ── CLI ──────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="Export tuned motor parameters to MJCF + robot.yaml."
)
parser.add_argument("--asset-path", type=str, default=_DEFAULT_ASSET)
parser.add_argument(
"--result", type=str, default=None,
help="Path to motor_sysid_result.json (auto-detected if omitted)",
)
args = parser.parse_args()
asset_path = Path(args.asset_path).resolve()
if args.result:
result_path = Path(args.result)
else:
result_path = asset_path / "motor_sysid_result.json"
if not result_path.exists():
raise FileNotFoundError(f"Result file not found: {result_path}")
result = json.loads(result_path.read_text())
params = result["best_params"]
export_tuned_files(asset_path=args.asset_path, params=params)
if __name__ == "__main__":
main()

580
src/sysid/optimize.py Normal file
View File

@@ -0,0 +1,580 @@
"""CMA-ES optimiser — fit simulation parameters to a real-robot recording.
Minimises the trajectory-matching cost between a MuJoCo rollout and a
recorded real-robot sequence. Uses the ``cmaes`` package (pure-Python
CMA-ES with native box-constraint support).
Motor parameters are **locked** from the motor-only sysid — only
pendulum/arm inertial parameters, joint dynamics, and ctrl_limit are
optimised. Velocities are optionally preprocessed with Savitzky-Golay
differentiation for cleaner targets.
Usage:
python -m src.sysid.optimize \
--robot-path assets/rotary_cartpole \
--recording assets/rotary_cartpole/recordings/capture_20260314_000435.npz
# Shorter run for testing:
python -m src.sysid.optimize \
--robot-path assets/rotary_cartpole \
--recording <file>.npz \
--max-generations 10 --population-size 8
"""
from __future__ import annotations
import argparse
import json
import time
from datetime import datetime
from pathlib import Path
import numpy as np
import structlog
from src.sysid.rollout import (
LOCKED_MOTOR_PARAMS,
PARAM_SETS,
ROTARY_CARTPOLE_PARAMS,
ParamSpec,
bounds_arrays,
defaults_vector,
params_to_dict,
rollout,
windowed_rollout,
)
log = structlog.get_logger()
# ── Velocity preprocessing ───────────────────────────────────────────
def _preprocess_recording(
recording: dict[str, np.ndarray],
preprocess_vel: bool = True,
sg_window: int = 7,
sg_polyorder: int = 3,
) -> dict[str, np.ndarray]:
"""Optionally recompute velocities using Savitzky-Golay differentiation.
Applies SG filtering to both motor_vel and pendulum_vel, replacing
the noisy firmware finite-difference velocities with smooth
analytical derivatives of the (clean) angle signals.
"""
if not preprocess_vel:
return recording
from scipy.signal import savgol_filter
rec = dict(recording)
times = rec["time"]
dt = float(np.mean(np.diff(times)))
# Motor velocity.
rec["motor_vel_raw"] = rec["motor_vel"].copy()
rec["motor_vel"] = savgol_filter(
rec["motor_angle"],
window_length=sg_window,
polyorder=sg_polyorder,
deriv=1,
delta=dt,
)
# Pendulum velocity.
rec["pendulum_vel_raw"] = rec["pendulum_vel"].copy()
rec["pendulum_vel"] = savgol_filter(
rec["pendulum_angle"],
window_length=sg_window,
polyorder=sg_polyorder,
deriv=1,
delta=dt,
)
motor_noise = np.std(rec["motor_vel_raw"] - rec["motor_vel"])
pend_noise = np.std(rec["pendulum_vel_raw"] - rec["pendulum_vel"])
log.info(
"velocity_preprocessed",
method="savgol",
sg_window=sg_window,
sg_polyorder=sg_polyorder,
dt_ms=f"{dt*1000:.1f}",
motor_noise_std=f"{motor_noise:.3f} rad/s",
pend_noise_std=f"{pend_noise:.3f} rad/s",
)
return rec
# ── Cost function ────────────────────────────────────────────────────
def _angle_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""Shortest signed angle difference, handling wrapping."""
return np.arctan2(np.sin(a - b), np.cos(a - b))
def _check_inertia_valid(params: dict[str, float]) -> bool:
"""Quick reject: pendulum inertia tensor must be positive-definite."""
ixx = params.get("pendulum_ixx", 6.16e-06)
iyy = params.get("pendulum_iyy", 6.16e-06)
izz = params.get("pendulum_izz", 1.23e-05)
ixy = params.get("pendulum_ixy", 6.10e-06)
det_xy = ixx * iyy - ixy * ixy
return det_xy > 0 and ixx > 0 and iyy > 0 and izz > 0
def _compute_trajectory_cost(
sim: dict[str, np.ndarray],
recording: dict[str, np.ndarray],
pos_weight: float = 1.0,
vel_weight: float = 0.1,
pendulum_scale: float = 3.0,
vel_outlier_threshold: float = 20.0,
) -> float:
"""Weighted MSE between sim and real trajectories.
pendulum_scale multiplies the pendulum terms relative to motor terms.
Samples where the *real* pendulum velocity exceeds
``vel_outlier_threshold`` (rad/s) are excluded from the velocity
terms. These are encoder-wrap artefacts (pendulum swinging past
vertical) that no simulator can reproduce.
"""
motor_err = _angle_diff(sim["motor_angle"], recording["motor_angle"])
pend_err = _angle_diff(sim["pendulum_angle"], recording["pendulum_angle"])
motor_vel_err = sim["motor_vel"] - recording["motor_vel"]
pend_vel_err = sim["pendulum_vel"] - recording["pendulum_vel"]
# Mask out encoder-wrap velocity spikes so the optimizer doesn't
# waste capacity fitting artefacts.
valid = np.abs(recording["pendulum_vel"]) < vel_outlier_threshold
if valid.sum() < len(valid):
motor_vel_err = motor_vel_err[valid]
pend_vel_err = pend_vel_err[valid]
return float(
pos_weight * np.mean(motor_err**2)
+ pos_weight * pendulum_scale * np.mean(pend_err**2)
+ vel_weight * np.mean(motor_vel_err**2)
+ vel_weight * pendulum_scale * np.mean(pend_vel_err**2)
)
def cost_function(
params_vec: np.ndarray,
recording: dict[str, np.ndarray],
robot_path: Path,
specs: list[ParamSpec],
sim_dt: float = 0.002,
substeps: int = 10,
pos_weight: float = 1.0,
vel_weight: float = 0.1,
pendulum_scale: float = 3.0,
window_duration: float = 0.5,
motor_params: dict[str, float] | None = None,
) -> float:
"""Compute trajectory-matching cost for a candidate parameter vector.
Uses **multiple-shooting** (windowed rollout): the recording is split
into short windows (default 0.5 s). Each window is initialised from
the real qpos/qvel, so early errors dont compound across the full
trajectory. This gives a much smoother cost landscape for CMA-ES.
Set ``window_duration=0`` to fall back to the original open-loop
single-shot rollout (not recommended).
"""
params = params_to_dict(params_vec, specs)
if not _check_inertia_valid(params):
return 1e6
try:
if window_duration > 0:
sim = windowed_rollout(
robot_path=robot_path,
params=params,
recording=recording,
window_duration=window_duration,
sim_dt=sim_dt,
substeps=substeps,
motor_params=motor_params,
)
else:
sim = rollout(
robot_path=robot_path,
params=params,
actions=recording["action"],
sim_dt=sim_dt,
substeps=substeps,
motor_params=motor_params,
)
except Exception as exc:
log.warning("rollout_failed", error=str(exc))
return 1e6
# Check for NaN in sim output.
for key in ("motor_angle", "motor_vel", "pendulum_angle", "pendulum_vel"):
if np.any(~np.isfinite(sim[key])):
return 1e6
return _compute_trajectory_cost(sim, recording, pos_weight, vel_weight, pendulum_scale)
# ── Parallel evaluation helper (module-level for pickling) ───────────
# Shared state set by the parent process before spawning workers.
_par_recording: dict[str, np.ndarray] = {}
_par_robot_path: Path = Path(".")
_par_specs: list[ParamSpec] = []
_par_kwargs: dict = {}
def _init_worker(recording, robot_path, specs, kwargs):
"""Initialiser run once per worker process."""
global _par_recording, _par_robot_path, _par_specs, _par_kwargs
_par_recording = recording
_par_robot_path = robot_path
_par_specs = specs
_par_kwargs = kwargs
def _eval_candidate(x_natural: np.ndarray) -> float:
"""Evaluate a single candidate — called in worker processes."""
return cost_function(
x_natural,
_par_recording,
_par_robot_path,
_par_specs,
**_par_kwargs,
)
# ── CMA-ES optimisation loop ────────────────────────────────────────
def optimize(
robot_path: str | Path,
recording_path: str | Path,
specs: list[ParamSpec] | None = None,
sigma0: float = 0.3,
population_size: int = 20,
max_generations: int = 1000,
sim_dt: float = 0.002,
substeps: int = 10,
pos_weight: float = 1.0,
vel_weight: float = 0.1,
pendulum_scale: float = 3.0,
window_duration: float = 0.5,
seed: int = 42,
preprocess_vel: bool = True,
sg_window: int = 7,
sg_polyorder: int = 3,
) -> dict:
"""Run CMA-ES optimisation and return results.
Motor parameters are locked from the motor-only sysid.
Only pendulum/arm parameters are optimised.
Returns a dict with:
best_params: dict[str, float]
best_cost: float
history: list of (generation, best_cost) tuples
recording: str (path used)
specs: list of param names
motor_params: dict of locked motor params
"""
from cmaes import CMA
robot_path = Path(robot_path).resolve()
recording_path = Path(recording_path).resolve()
if specs is None:
specs = ROTARY_CARTPOLE_PARAMS
motor_params = dict(LOCKED_MOTOR_PARAMS)
log.info(
"motor_params_locked",
n_params=len(motor_params),
gear_avg=f"{(motor_params['actuator_gear_pos'] + motor_params['actuator_gear_neg']) / 2:.4f}",
)
# Load recording.
recording = dict(np.load(recording_path))
# Preprocessing: SG velocity recomputation.
recording = _preprocess_recording(
recording,
preprocess_vel=preprocess_vel,
sg_window=sg_window,
sg_polyorder=sg_polyorder,
)
n_samples = len(recording["time"])
duration = recording["time"][-1] - recording["time"][0]
n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
log.info(
"recording_loaded",
path=str(recording_path),
samples=n_samples,
duration=f"{duration:.1f}s",
window_duration=f"{window_duration}s",
n_windows=n_windows,
)
# Initial point (defaults) — normalised to [0, 1] for CMA-ES.
lo, hi = bounds_arrays(specs)
x0 = defaults_vector(specs)
# Normalise to [0, 1] for the optimizer (better conditioned).
span = hi - lo
span[span == 0] = 1.0 # avoid division by zero
def to_normed(x: np.ndarray) -> np.ndarray:
return (x - lo) / span
def from_normed(x_n: np.ndarray) -> np.ndarray:
return x_n * span + lo
x0_normed = to_normed(x0)
bounds_normed = np.column_stack(
[np.zeros(len(specs)), np.ones(len(specs))]
)
optimizer = CMA(
mean=x0_normed,
sigma=sigma0,
bounds=bounds_normed,
population_size=population_size,
seed=seed,
)
best_cost = float("inf")
best_params_vec = x0.copy()
history: list[tuple[int, float]] = []
log.info(
"cmaes_starting",
n_params=len(specs),
population=population_size,
max_gens=max_generations,
sigma0=sigma0,
)
t0 = time.monotonic()
# ── Parallel evaluation setup ────────────────────────────────
# Each candidate is independent — evaluate them in parallel using
# a process pool. Falls back to sequential if n_workers=1.
import multiprocessing as mp
n_workers = max(1, mp.cpu_count() - 1) # leave 1 core free
eval_kwargs = dict(
sim_dt=sim_dt,
substeps=substeps,
pos_weight=pos_weight,
vel_weight=vel_weight,
pendulum_scale=pendulum_scale,
window_duration=window_duration,
motor_params=motor_params,
)
log.info("parallel_workers", n_workers=n_workers)
# Create a persistent pool (avoids per-generation fork overhead).
pool = None
if n_workers > 1:
pool = mp.Pool(
n_workers,
initializer=_init_worker,
initargs=(recording, robot_path, specs, eval_kwargs),
)
for gen in range(max_generations):
# Ask all candidates first.
candidates_normed = []
candidates_natural = []
for _ in range(optimizer.population_size):
x_normed = optimizer.ask()
x_natural = from_normed(x_normed)
x_natural = np.clip(x_natural, lo, hi)
candidates_normed.append(x_normed)
candidates_natural.append(x_natural)
# Evaluate in parallel.
if pool is not None:
costs = pool.map(_eval_candidate, candidates_natural)
else:
costs = [cost_function(
x, recording, robot_path, specs, **eval_kwargs
) for x in candidates_natural]
solutions = list(zip(candidates_normed, costs))
for x_natural, c in zip(candidates_natural, costs):
if c < best_cost:
best_cost = c
best_params_vec = x_natural.copy()
optimizer.tell(solutions)
history.append((gen, best_cost))
elapsed = time.monotonic() - t0
if gen % 5 == 0 or gen == max_generations - 1:
log.info(
"cmaes_generation",
gen=gen,
best_cost=f"{best_cost:.6f}",
elapsed=f"{elapsed:.1f}s",
gen_best=f"{min(c for _, c in solutions):.6f}",
)
total_time = time.monotonic() - t0
# Clean up process pool.
if pool is not None:
pool.close()
pool.join()
best_params = params_to_dict(best_params_vec, specs)
log.info(
"cmaes_finished",
best_cost=f"{best_cost:.6f}",
total_time=f"{total_time:.1f}s",
evaluations=max_generations * population_size,
)
# Log parameter comparison.
defaults = params_to_dict(defaults_vector(specs), specs)
for name in best_params:
d = defaults[name]
b = best_params[name]
change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
log.info(
"param_result",
name=name,
default=f"{d:.6g}",
tuned=f"{b:.6g}",
change=f"{change_pct:+.1f}%",
)
return {
"best_params": best_params,
"best_cost": best_cost,
"history": history,
"recording": str(recording_path),
"param_names": [s.name for s in specs],
"defaults": {s.name: s.default for s in specs},
"motor_params": motor_params,
"preprocess_vel": preprocess_vel,
"timestamp": datetime.now().isoformat(),
}
# ── CLI entry point ──────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="Fit simulation parameters to a real-robot recording (CMA-ES)."
)
parser.add_argument(
"--robot-path",
type=str,
default="assets/rotary_cartpole",
help="Path to robot asset directory",
)
parser.add_argument(
"--recording",
type=str,
required=True,
help="Path to .npz recording file",
)
parser.add_argument("--sigma0", type=float, default=0.3)
parser.add_argument("--population-size", type=int, default=20)
parser.add_argument("--max-generations", type=int, default=200)
parser.add_argument("--sim-dt", type=float, default=0.002)
parser.add_argument("--substeps", type=int, default=10)
parser.add_argument("--pos-weight", type=float, default=1.0)
parser.add_argument("--vel-weight", type=float, default=0.1)
parser.add_argument("--pendulum-scale", type=float, default=3.0,
help="Multiplier for pendulum terms relative to motor (default 3.0)")
parser.add_argument(
"--window-duration",
type=float,
default=0.5,
help="Shooting window length in seconds (0 = open-loop, default 0.5)",
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
"--no-export",
action="store_true",
help="Skip exporting tuned files (results JSON only)",
)
parser.add_argument(
"--no-preprocess-vel",
action="store_true",
help="Skip Savitzky-Golay velocity preprocessing",
)
parser.add_argument("--sg-window", type=int, default=7,
help="Savitzky-Golay window length (odd, default 7)")
parser.add_argument("--sg-polyorder", type=int, default=3,
help="Savitzky-Golay polynomial order (default 3)")
parser.add_argument(
"--param-set",
type=str,
default="full",
choices=list(PARAM_SETS.keys()),
help="Parameter set to optimize: 'reduced' (6 params, fast) or 'full' (15 params)",
)
args = parser.parse_args()
specs = PARAM_SETS[args.param_set]
result = optimize(
robot_path=args.robot_path,
recording_path=args.recording,
specs=specs,
sigma0=args.sigma0,
population_size=args.population_size,
max_generations=args.max_generations,
sim_dt=args.sim_dt,
substeps=args.substeps,
pos_weight=args.pos_weight,
vel_weight=args.vel_weight,
pendulum_scale=args.pendulum_scale,
window_duration=args.window_duration,
seed=args.seed,
preprocess_vel=not args.no_preprocess_vel,
sg_window=args.sg_window,
sg_polyorder=args.sg_polyorder,
)
# Save results JSON.
robot_path = Path(args.robot_path).resolve()
result_path = robot_path / "sysid_result.json"
# Convert numpy types for JSON serialisation.
result_json = {
k: v for k, v in result.items() if k != "history"
}
result_json["history_summary"] = {
"first_cost": result["history"][0][1] if result["history"] else None,
"final_cost": result["history"][-1][1] if result["history"] else None,
"generations": len(result["history"]),
}
result_path.write_text(json.dumps(result_json, indent=2, default=str))
log.info("results_saved", path=str(result_path))
# Export tuned files unless --no-export.
if not args.no_export:
from src.sysid.export import export_tuned_files
export_tuned_files(
robot_path=args.robot_path,
params=result["best_params"],
motor_params=result.get("motor_params"),
)
if __name__ == "__main__":
main()

425
src/sysid/rollout.py Normal file
View File

@@ -0,0 +1,425 @@
"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
Given a parameter vector and a recorded action sequence, builds a MuJoCo
model with overridden physics parameters, replays the actions, and returns
the simulated trajectory for comparison with the real recording.
This module is the inner loop of the CMA-ES optimizer: it is called once
per candidate parameter vector per generation.
Motor parameters are **locked** from the unified sysid result.
The optimizer only fits
pendulum/arm inertial parameters, pendulum joint dynamics, and
``ctrl_limit``. The asymmetric motor model (bias, deadzone, gear
compensation, Coulomb + Stribeck friction, viscous damping) is applied
via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()`` —
the same code the training runners use, so sim == sysid by construction.
"""
from __future__ import annotations
import dataclasses
import os
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path
import mujoco
import numpy as np
import yaml
from src.core.robot import ActuatorConfig, JointConfig, RobotConfig
from src.runners.mujoco import ActuatorLimits, load_mujoco_model
from src.sysid._urdf import patch_link_inertials
# ── Locked motor parameters (from the unified sysid) ────────────────
# These are FIXED and not optimised. They come from the unified
# 28-param sysid run (assets/rotary_cartpole/sysid_result.json,
# cost 0.925) — Stribeck friction + action bias + ~96 ms motor lag.
LOCKED_MOTOR_PARAMS: dict[str, float] = {
"actuator_gear_pos": 0.846499,
"actuator_gear_neg": 1.183733,
"actuator_filter_tau": 0.096263,
"motor_damping_pos": 0.013165,
"motor_damping_neg": 0.015452,
"motor_armature": 0.001676,
"motor_frictionloss_pos": 0.014244,
"motor_frictionloss_neg": 0.001005,
"stribeck_friction_boost": 0.068594,
"stribeck_vel": 5.279594,
"motor_deadzone_pos": 0.181097,
"motor_deadzone_neg": 0.202072,
"action_bias": 0.056566,
}
# ── Tunable parameter specification ──────────────────────────────────
@dataclasses.dataclass
class ParamSpec:
"""Specification for a single tunable parameter."""
name: str
default: float
lower: float
upper: float
log_scale: bool = False # optimise in log-space (masses, inertias)
# Pendulum sysid parameters — motor params are LOCKED (not here).
# Order matters: the optimizer maps a flat vector to these specs.
# Defaults are from the URDF exported by Fusion 360.
ROTARY_CARTPOLE_PARAMS: list[ParamSpec] = [
# ── Arm link (URDF) ──────────────────────────────────────────
ParamSpec("arm_mass", 0.02110, 0.005, 0.08, log_scale=True),
ParamSpec("arm_com_x", -0.00710, -0.03, 0.03),
ParamSpec("arm_com_y", 0.00085, -0.02, 0.02),
ParamSpec("arm_com_z", 0.00795, -0.02, 0.02),
# ── Pendulum link (URDF) ─────────────────────────────────────
ParamSpec("pendulum_mass", 0.03937, 0.010, 0.10, log_scale=True),
ParamSpec("pendulum_com_x", 0.06025, 0.01, 0.15),
ParamSpec("pendulum_com_y", -0.07602, -0.20, 0.0),
ParamSpec("pendulum_com_z", -0.00346, -0.05, 0.05),
ParamSpec("pendulum_ixx", 6.20e-05, 1e-07, 1e-03, log_scale=True),
ParamSpec("pendulum_iyy", 3.70e-05, 1e-07, 1e-03, log_scale=True),
ParamSpec("pendulum_izz", 7.83e-05, 1e-07, 1e-03, log_scale=True),
ParamSpec("pendulum_ixy", -6.93e-06, -1e-03, 1e-03),
# ── Pendulum joint dynamics ──────────────────────────────────
ParamSpec("pendulum_damping", 0.0001, 1e-6, 0.05, log_scale=True),
ParamSpec("pendulum_frictionloss", 0.0001, 1e-6, 0.05, log_scale=True),
# ── Hardware realism (control pipeline) ────────────────────
ParamSpec("ctrl_limit", 0.588, 0.45, 0.70), # MAX_MOTOR_SPEED / 255
]
# Extended set: full params + motor armature (compensates for the
# motor-only sysid having captured arm/pendulum loading in armature).
EXTENDED_PARAMS: list[ParamSpec] = ROTARY_CARTPOLE_PARAMS + [
ParamSpec("motor_armature", 0.00277, 0.0005, 0.02, log_scale=True),
]
# Reduced set: only the 6 most impactful pendulum parameters.
# Good for a fast first pass — converges in ~50 generations.
REDUCED_PARAMS: list[ParamSpec] = [
ParamSpec("pendulum_mass", 0.03937, 0.010, 0.10, log_scale=True),
ParamSpec("pendulum_com_x", 0.06025, 0.01, 0.15),
ParamSpec("pendulum_com_y", -0.07602, -0.20, 0.0),
ParamSpec("pendulum_ixx", 6.20e-05, 1e-07, 1e-03, log_scale=True),
ParamSpec("pendulum_damping", 0.0001, 1e-6, 0.05, log_scale=True),
ParamSpec("pendulum_frictionloss", 0.0001, 1e-6, 0.05, log_scale=True),
]
PARAM_SETS: dict[str, list[ParamSpec]] = {
"full": ROTARY_CARTPOLE_PARAMS,
"extended": EXTENDED_PARAMS,
"reduced": REDUCED_PARAMS,
}
def params_to_dict(
values: np.ndarray, specs: list[ParamSpec] | None = None
) -> dict[str, float]:
"""Convert a flat parameter vector to a named dict."""
if specs is None:
specs = ROTARY_CARTPOLE_PARAMS
return {s.name: float(values[i]) for i, s in enumerate(specs)}
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
"""Return the default parameter vector (in natural space)."""
if specs is None:
specs = ROTARY_CARTPOLE_PARAMS
return np.array([s.default for s in specs], dtype=np.float64)
def bounds_arrays(
specs: list[ParamSpec] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Return (lower, upper) bound arrays."""
if specs is None:
specs = ROTARY_CARTPOLE_PARAMS
lo = np.array([s.lower for s in specs], dtype=np.float64)
hi = np.array([s.upper for s in specs], dtype=np.float64)
return lo, hi
# ── MuJoCo model building with parameter overrides ──────────────────
def _build_model(
robot_path: Path,
params: dict[str, float],
motor_params: dict[str, float] | None = None,
) -> tuple[mujoco.MjModel, ActuatorConfig]:
"""Build a MuJoCo model with sysid overrides.
Returns (model, actuator) — use ``actuator.transform_ctrl()`` and
``actuator.compute_motor_force()`` in the rollout loop.
"""
if motor_params is None:
motor_params = LOCKED_MOTOR_PARAMS
robot_path = Path(robot_path).resolve()
# ── Patch URDF inertial parameters to a temp file ────────────
robot_yaml = yaml.safe_load((robot_path / "robot.yaml").read_text())
urdf_path = robot_path / robot_yaml["urdf"]
tree = ET.parse(urdf_path)
patch_link_inertials(tree.getroot(), params)
fd, tmp_urdf = tempfile.mkstemp(
suffix=".urdf", prefix="_sysid_", dir=str(robot_path),
)
os.close(fd)
tmp_urdf_path = Path(tmp_urdf)
tree.write(str(tmp_urdf_path), xml_declaration=True, encoding="unicode")
# ── Build RobotConfig with full motor sysid values ───────────
gear_pos = motor_params.get("actuator_gear_pos", 0.424182)
gear_neg = motor_params.get("actuator_gear_neg", 0.425031)
motor_armature = params.get(
"motor_armature",
motor_params.get("motor_armature", 0.00277342),
)
pend_damping = params.get("pendulum_damping", 0.0001)
pend_frictionloss = params.get("pendulum_frictionloss", 0.0001)
act_cfg = robot_yaml["actuators"][0]
ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])
# The fitted ctrl_limit overrides the YAML ctrl_range so the rollout
# saturates at exactly the identified PWM bound.
if "ctrl_limit" in params:
ctrl_lo, ctrl_hi = -params["ctrl_limit"], params["ctrl_limit"]
actuator = ActuatorConfig(
joint=act_cfg["joint"],
type="motor",
gear=(gear_pos, gear_neg),
ctrl_range=(ctrl_lo, ctrl_hi),
deadzone=(
motor_params.get("motor_deadzone_pos", 0.181),
motor_params.get("motor_deadzone_neg", 0.202),
),
damping=(
motor_params.get("motor_damping_pos", 0.013),
motor_params.get("motor_damping_neg", 0.015),
),
frictionloss=(
motor_params.get("motor_frictionloss_pos", 0.014),
motor_params.get("motor_frictionloss_neg", 0.001),
),
filter_tau=motor_params.get("actuator_filter_tau", 0.096),
viscous_quadratic=motor_params.get("viscous_quadratic", 0.0),
back_emf_gain=motor_params.get("back_emf_gain", 0.0),
stribeck_friction_boost=motor_params.get("stribeck_friction_boost", 0.0),
stribeck_vel=motor_params.get("stribeck_vel", 2.0),
action_bias=motor_params.get("action_bias", 0.0),
)
robot = RobotConfig(
urdf_path=tmp_urdf_path,
actuators=[actuator],
joints={
"motor_joint": JointConfig(
damping=0.0,
armature=motor_armature,
frictionloss=0.0,
),
"pendulum_joint": JointConfig(
damping=pend_damping,
frictionloss=pend_frictionloss,
),
},
)
try:
model = load_mujoco_model(robot)
finally:
tmp_urdf_path.unlink(missing_ok=True)
return model, actuator
# ── Simulation rollout ───────────────────────────────────────────────
def rollout(
robot_path: str | Path,
params: dict[str, float],
actions: np.ndarray,
sim_dt: float = 0.002,
substeps: int = 10,
motor_params: dict[str, float] | None = None,
) -> dict[str, np.ndarray]:
"""Replay recorded actions in MuJoCo with overridden parameters.
Parameters
----------
robot_path : asset directory
params : named parameter overrides (pendulum/arm only)
actions : (N,) normalised actions [-1, 1] from the recording
sim_dt : MuJoCo physics timestep
substeps : physics substeps per control step
motor_params : locked motor params (default: LOCKED_MOTOR_PARAMS)
Returns
-------
dict with keys: motor_angle, motor_vel, pendulum_angle, pendulum_vel
Each is an (N,) numpy array of simulated values.
"""
if motor_params is None:
motor_params = LOCKED_MOTOR_PARAMS
robot_path = Path(robot_path).resolve()
model, actuator = _build_model(robot_path, params, motor_params)
model.opt.timestep = sim_dt
data = mujoco.MjData(model)
mujoco.mj_resetData(model, data)
n = len(actions)
sim_motor_angle = np.zeros(n, dtype=np.float64)
sim_motor_vel = np.zeros(n, dtype=np.float64)
sim_pend_angle = np.zeros(n, dtype=np.float64)
sim_pend_vel = np.zeros(n, dtype=np.float64)
limits = ActuatorLimits(model)
for i in range(n):
# transform_ctrl clips to the (fitted) ctrl_range internally.
ctrl = actuator.transform_ctrl(float(actions[i]))
data.ctrl[0] = ctrl
for _ in range(substeps):
limits.enforce(model, data)
data.qfrc_applied[0] = actuator.compute_motor_force(data.qvel[0], ctrl)
mujoco.mj_step(model, data)
sim_motor_angle[i] = data.qpos[0]
sim_pend_angle[i] = data.qpos[1]
sim_motor_vel[i] = data.qvel[0]
sim_pend_vel[i] = data.qvel[1]
return {
"motor_angle": sim_motor_angle,
"motor_vel": sim_motor_vel,
"pendulum_angle": sim_pend_angle,
"pendulum_vel": sim_pend_vel,
}
def windowed_rollout(
robot_path: str | Path,
params: dict[str, float],
recording: dict[str, np.ndarray],
window_duration: float = 0.5,
sim_dt: float = 0.002,
substeps: int = 10,
motor_params: dict[str, float] | None = None,
) -> dict[str, np.ndarray | float]:
"""Multiple-shooting rollout — split recording into short windows.
For each window:
1. Initialize MuJoCo state from the real qpos/qvel at the window start.
2. Replay the recorded actions within the window.
3. Record the simulated output.
Motor dynamics (asymmetric friction, damping, back-EMF, etc.) are
applied via qfrc_applied using the locked motor sysid parameters.
Parameters
----------
robot_path : asset directory
params : named parameter overrides (pendulum/arm only)
recording : dict with keys time, action, motor_angle, motor_vel,
pendulum_angle, pendulum_vel (all 1D arrays of length N)
window_duration : length of each shooting window in seconds
sim_dt : MuJoCo physics timestep
substeps : physics substeps per control step
motor_params : locked motor params (default: LOCKED_MOTOR_PARAMS)
Returns
-------
dict with:
motor_angle, motor_vel, pendulum_angle, pendulum_vel — (N,) arrays
(stitched from per-window simulations)
n_windows — number of windows used
"""
if motor_params is None:
motor_params = LOCKED_MOTOR_PARAMS
robot_path = Path(robot_path).resolve()
model, actuator = _build_model(robot_path, params, motor_params)
model.opt.timestep = sim_dt
data = mujoco.MjData(model)
times = recording["time"]
actions = recording["action"]
real_motor = recording["motor_angle"]
real_motor_vel = recording["motor_vel"]
real_pend = recording["pendulum_angle"]
real_pend_vel = recording["pendulum_vel"]
n = len(actions)
sim_motor_angle = np.zeros(n, dtype=np.float64)
sim_motor_vel = np.zeros(n, dtype=np.float64)
sim_pend_angle = np.zeros(n, dtype=np.float64)
sim_pend_vel = np.zeros(n, dtype=np.float64)
limits = ActuatorLimits(model)
t0 = times[0]
t_end = times[-1]
window_starts: list[int] = []
current_t = t0
while current_t < t_end:
idx = int(np.searchsorted(times, current_t))
idx = min(idx, n - 1)
window_starts.append(idx)
current_t += window_duration
n_windows = len(window_starts)
for w, w_start in enumerate(window_starts):
w_end = window_starts[w + 1] if w + 1 < n_windows else n
mujoco.mj_resetData(model, data)
data.qpos[0] = real_motor[w_start]
data.qpos[1] = real_pend[w_start]
data.qvel[0] = real_motor_vel[w_start]
data.qvel[1] = real_pend_vel[w_start]
data.ctrl[:] = 0.0
mujoco.mj_forward(model, data)
for i in range(w_start, w_end):
# transform_ctrl clips to the (fitted) ctrl_range internally.
ctrl = actuator.transform_ctrl(float(actions[i]))
data.ctrl[0] = ctrl
for _ in range(substeps):
limits.enforce(model, data)
data.qfrc_applied[0] = actuator.compute_motor_force(data.qvel[0], ctrl)
mujoco.mj_step(model, data)
sim_motor_angle[i] = data.qpos[0]
sim_pend_angle[i] = data.qpos[1]
sim_motor_vel[i] = data.qvel[0]
sim_pend_vel[i] = data.qvel[1]
return {
"motor_angle": sim_motor_angle,
"motor_vel": sim_motor_vel,
"pendulum_angle": sim_pend_angle,
"pendulum_vel": sim_pend_vel,
"n_windows": n_windows,
}

248
src/sysid/visualize.py Normal file
View File

@@ -0,0 +1,248 @@
"""Visualise system identification results — real vs simulated trajectories.
Loads a recording and runs simulation with both the original and tuned
parameters, then plots a 4-panel comparison (motor angle, motor vel,
pendulum angle, pendulum vel) over time.
Usage:
python -m src.sysid.visualize \
--robot-path assets/rotary_cartpole \
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
# Also compare with tuned parameters:
python -m src.sysid.visualize \
--robot-path assets/rotary_cartpole \
--recording <file>.npz \
--result assets/rotary_cartpole/sysid_result.json
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import structlog
log = structlog.get_logger()
def _run_sim(
robot_path: Path,
params: dict[str, float],
recording: dict[str, np.ndarray],
window_duration: float,
sim_dt: float,
substeps: int,
) -> dict[str, np.ndarray]:
"""Run windowed or open-loop rollout depending on window_duration."""
from src.sysid.rollout import rollout, windowed_rollout
if window_duration > 0:
return windowed_rollout(
robot_path=robot_path, params=params, recording=recording,
window_duration=window_duration, sim_dt=sim_dt,
substeps=substeps,
)
return rollout(
robot_path=robot_path, params=params, actions=recording["action"],
substeps=substeps,
)
def visualize(
robot_path: str | Path,
recording_path: str | Path,
result_path: str | Path | None = None,
sim_dt: float = 0.002,
substeps: int = 10,
window_duration: float = 0.5,
save_path: str | Path | None = None,
show: bool = True,
) -> None:
"""Generate comparison plot."""
import matplotlib.pyplot as plt
from src.sysid.rollout import (
ROTARY_CARTPOLE_PARAMS,
defaults_vector,
params_to_dict,
)
robot_path = Path(robot_path).resolve()
recording = dict(np.load(recording_path))
sim_kwargs = dict(
robot_path=robot_path, recording=recording,
window_duration=window_duration, sim_dt=sim_dt,
substeps=substeps,
)
t = recording["time"]
actions = recording["action"]
# ── Simulate with default parameters ─────────────────────────
default_params = params_to_dict(
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
)
log.info("simulating_default_params", windowed=window_duration > 0)
sim_default = _run_sim(params=default_params, **sim_kwargs)
# ── Simulate with tuned parameters (if available) ────────────
# Resolve result path (explicit or auto-detect).
if result_path is None:
auto = robot_path / "sysid_result.json"
if auto.exists():
result_path = auto
sim_tuned = None
tuned_cost = None
if result_path is not None:
result_path = Path(result_path)
if result_path.exists():
result = json.loads(result_path.read_text())
tuned_params = result.get("best_params", {})
tuned_cost = result.get("best_cost")
log.info("simulating_tuned_params", cost=tuned_cost)
sim_tuned = _run_sim(params=tuned_params, **sim_kwargs)
else:
log.warning("result_file_not_found", path=str(result_path))
# ── Plot ─────────────────────────────────────────────────────
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
channels = [
("motor_angle", "Motor Angle (rad)"),
("motor_vel", "Motor Velocity (rad/s)"),
("pendulum_angle", "Pendulum Angle (rad)"),
("pendulum_vel", "Pendulum Velocity (rad/s)"),
]
for ax, (key, ylabel) in zip(axes[:4], channels):
real = recording[key]
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
ax.plot(
t,
sim_default[key],
"--",
color="#d62728",
linewidth=1.0,
alpha=0.7,
label="Sim (original)",
)
if sim_tuned is not None:
ax.plot(
t,
sim_tuned[key],
"--",
color="#2ca02c",
linewidth=1.0,
alpha=0.7,
label="Sim (tuned)",
)
ax.set_ylabel(ylabel)
ax.legend(loc="upper right", fontsize=8)
ax.grid(True, alpha=0.3)
# Action plot (bottom panel).
axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
axes[4].set_ylabel("Action (norm)")
axes[4].set_xlabel("Time (s)")
axes[4].grid(True, alpha=0.3)
axes[4].set_ylim(-1.1, 1.1)
# Title with cost info.
title = "System Identification — Real vs Simulated Trajectories"
if tuned_cost is not None:
# Compute original cost for comparison.
from src.sysid.optimize import cost_function
orig_cost = cost_function(
defaults_vector(ROTARY_CARTPOLE_PARAMS),
recording,
robot_path,
ROTARY_CARTPOLE_PARAMS,
sim_dt=sim_dt,
substeps=substeps,
window_duration=window_duration,
)
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
title += f" ({improvement:+.1f}%)"
fig.suptitle(title, fontsize=12)
plt.tight_layout()
if save_path:
save_path = Path(save_path)
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
log.info("figure_saved", path=str(save_path))
if show:
plt.show()
else:
plt.close(fig)
# ── CLI entry point ──────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="Visualise system identification results."
)
parser.add_argument(
"--robot-path",
type=str,
default="assets/rotary_cartpole",
)
parser.add_argument(
"--recording",
type=str,
required=True,
help="Path to .npz recording file",
)
parser.add_argument(
"--result",
type=str,
default=None,
help="Path to sysid_result.json (auto-detected if omitted)",
)
parser.add_argument("--sim-dt", type=float, default=0.002)
parser.add_argument("--substeps", type=int, default=10)
parser.add_argument(
"--window-duration",
type=float,
default=0.5,
help="Shooting window length in seconds (0 = open-loop)",
)
parser.add_argument(
"--save",
type=str,
default=None,
help="Save figure to this path (PNG, PDF, …)",
)
parser.add_argument(
"--no-show",
action="store_true",
help="Don't show interactive window (useful for CI)",
)
args = parser.parse_args()
visualize(
robot_path=args.robot_path,
recording_path=args.recording,
result_path=args.result,
sim_dt=args.sim_dt,
substeps=args.substeps,
window_duration=args.window_duration,
save_path=args.save,
show=not args.no_show,
)
if __name__ == "__main__":
main()

View File

@@ -4,19 +4,24 @@ import tempfile
from pathlib import Path
import numpy as np
import structlog
import torch
import tqdm
from clearml import Logger
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.torch import RandomMemory
from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.trainers.torch import SequentialTrainer
from src.core.runner import BaseRunner
from clearml import Task, Logger
import torch
from gymnasium import spaces
from skrl.memories.torch import RandomMemory
from src.models.mlp import SharedMLP
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.trainers.torch import SequentialTrainer
log = structlog.get_logger()
@dataclasses.dataclass
class TrainerConfig:
# PPO
rollout_steps: int = 2048
learning_epochs: int = 8
mini_batches: int = 4
@@ -26,33 +31,41 @@ class TrainerConfig:
clip_ratio: float = 0.2
value_loss_scale: float = 0.5
entropy_loss_scale: float = 0.01
kl_threshold: float = 0.01 # KL-adaptive LR target; 0 = fixed LR
hidden_sizes: tuple[int, ...] = (64, 64)
# Policy
initial_log_std: float = 0.5 # initial exploration noise
min_log_std: float = -2.0 # minimum exploration noise
max_log_std: float = 2.0 # maximum exploration noise (2.0 ≈ σ=7.4)
# Training
total_timesteps: int = 1_000_000
log_interval: int = 10
checkpoint_interval: int = 50_000
# Video recording
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
# Video recording (uploaded to ClearML)
record_video_every: int = 10_000 # 0 = disabled
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
clearml_project: str | None = None
clearml_task: str | None = None
# History encoder (implicit adaptation). The window size comes from
# the runner (runner.history_length) — single source of truth.
embedding_dim: int = 32 # history encoder output dimension
# ── Video-recording trainer ──────────────────────────────────────────
class VideoRecordingTrainer(SequentialTrainer):
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
super().__init__(env=env, agents=agents, cfg=cfg)
self._trainer_config = trainer_config
self._tcfg = trainer_config
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
def single_agent_train(self) -> None:
"""Override to add periodic video recording."""
assert self.num_simultaneous_agents == 1
assert self.env.num_agents == 1
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
states, infos = self.env.reset()
@@ -61,26 +74,17 @@ class VideoRecordingTrainer(SequentialTrainer):
disable=self.disable_progressbar,
file=sys.stdout,
):
# Pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
with torch.no_grad():
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
if not self.headless:
self.env.render()
self.agents.record_transition(
states=states,
actions=actions,
rewards=rewards,
next_states=next_states,
terminated=terminated,
truncated=truncated,
infos=infos,
timestep=timestep,
timesteps=self.timesteps,
states=states, actions=actions, rewards=rewards,
next_states=next_states, terminated=terminated,
truncated=truncated, infos=infos,
timestep=timestep, timesteps=self.timesteps,
)
if self.environment_info in infos:
@@ -90,7 +94,7 @@ class VideoRecordingTrainer(SequentialTrainer):
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
# Reset environments
# Auto-reset for multi-env; single-env resets manually
if self.env.num_envs > 1:
states = next_states
else:
@@ -100,111 +104,125 @@ class VideoRecordingTrainer(SequentialTrainer):
else:
states = next_states
# Record video at intervals
cfg = self._trainer_config
# Periodic video recording. Recording steps the (shared) envs,
# so it returns a freshly reset observation — the training loop
# MUST continue from it, otherwise the recorded transitions no
# longer match the actual env state.
if (
cfg
and cfg.record_video_every > 0
and (timestep + 1) % cfg.record_video_every == 0
self._tcfg
and self._tcfg.record_video_every > 0
and (timestep + 1) % self._tcfg.record_video_every == 0
):
self._record_video(timestep + 1)
fresh_states = self._record_video(timestep + 1)
if fresh_states is not None:
states = fresh_states
def _get_video_fps(self) -> int:
"""Derive video fps from the simulation rate, or use configured value."""
cfg = self._trainer_config
if cfg.record_video_fps > 0:
return cfg.record_video_fps
# Auto-derive from runner's simulation parameters
runner = self.env
dt = getattr(runner.config, "dt", 0.02)
substeps = getattr(runner.config, "substeps", 1)
# ── helpers ───────────────────────────────────────────────────────
def _get_fps(self) -> int:
if self._tcfg and self._tcfg.record_video_fps > 0:
return self._tcfg.record_video_fps
dt = getattr(self.env.config, "dt", 0.02)
substeps = getattr(self.env.config, "substeps", 1)
# SerialRunner has dt but no substeps — dt *is* the control period.
return max(1, int(round(1.0 / (dt * substeps))))
def _record_video(self, timestep: int) -> None:
"""Record evaluation episodes and upload to ClearML."""
def _record_video(self, timestep: int) -> torch.Tensor | None:
"""Record an eval episode and upload it to ClearML.
Returns the freshly reset observation the training loop should
continue from (the recording steps the shared envs), or ``None``
if even the final reset failed.
"""
try:
import imageio.v3 as iio
except ImportError:
iio = None
# Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
# failure crash training — log it and skip the video.
if iio is not None:
try:
import imageio as iio
except ImportError:
return
fps = self._get_fps()
max_steps = getattr(self.env.env.config, "max_steps", 500)
frames: list[np.ndarray] = []
cfg = self._trainer_config
fps = self._get_video_fps()
min_frames = int(cfg.record_video_min_seconds * fps)
max_frames = min_frames * 3 # hard cap to prevent runaway recording
frames: list[np.ndarray] = []
while len(frames) < min_frames and len(frames) < max_frames:
obs, _ = self.env.reset()
done = False
steps = 0
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
while not done and steps < max_episode_steps:
obs, _ = self.env.reset()
with torch.no_grad():
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action)
frame = self.env.render(mode="rgb_array")
if frame is not None:
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
done = (terminated | truncated).any().item()
steps += 1
if len(frames) >= max_frames:
break
for _ in range(max_steps):
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action)
if frames:
video_path = str(self._video_dir / f"step_{timestep}.mp4")
iio.imwrite(video_path, frames, fps=fps)
frame = self.env.render()
if frame is not None:
frames.append(frame)
logger = Logger.current_logger()
if logger:
logger.report_media(
title="Training Video",
series=f"step_{timestep}",
local_path=video_path,
iteration=timestep,
)
if (terminated | truncated).any().item():
break
# Reset back to training state after recording
self.env.reset()
if frames:
path = str(self._video_dir / f"step_{timestep}.mp4")
iio.imwrite(path, frames, fps=fps)
logger = Logger.current_logger()
if logger:
logger.report_media(
"Training Video", f"step_{timestep}",
local_path=path, iteration=timestep,
)
except Exception as exc:
log.warning("video_recording_failed", timestep=timestep, error=str(exc))
# Always leave the envs freshly reset and hand the new observation
# back to the training loop.
try:
with torch.no_grad():
states, _ = self.env.reset()
return states
except Exception as exc:
log.warning("post_video_reset_failed", timestep=timestep, error=str(exc))
return None
# ── Main trainer ─────────────────────────────────────────────────────
class Trainer:
def __init__(self, runner: BaseRunner, config: TrainerConfig):
self.runner = runner
self.config = config
self._init_clearml()
self._init_agent()
def _init_clearml(self) -> None:
if self.config.clearml_project and self.config.clearml_task:
self.clearml_task = Task.init(
project_name=self.config.clearml_project,
task_name=self.config.clearml_task,
)
else:
self.clearml_task = None
def _init_agent(self) -> None:
device: torch.device = self.runner.device
obs_space: spaces.Space = self.runner.observation_space
act_space: spaces.Space = self.runner.action_space
num_envs: int = self.runner.num_envs
device = self.runner.device
obs_space = self.runner.observation_space
act_space = self.runner.action_space
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
self.memory = RandomMemory(
memory_size=self.config.rollout_steps,
num_envs=self.runner.num_envs,
device=device,
)
self.model: SharedMLP = SharedMLP(
# Determine raw obs dim (without history augmentation) and the
# history window size — both come from the runner so the model
# always matches the observation layout it produces.
raw_obs_dim = self.runner.env.observation_space.shape[0]
history_length = getattr(self.runner.config, "history_length", 0)
self.model = SharedMLP(
observation_space=obs_space,
action_space=act_space,
device=device,
hidden_sizes=self.config.hidden_sizes,
initial_log_std=self.config.initial_log_std,
min_log_std=self.config.min_log_std,
max_log_std=self.config.max_log_std,
history_length=history_length,
raw_obs_dim=raw_obs_dim,
embedding_dim=self.config.embedding_dim,
)
models = {
"policy": self.model,
"value": self.model,
}
models = {"policy": self.model, "value": self.model}
agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg.update({
@@ -217,9 +235,28 @@ class Trainer:
"ratio_clip": self.config.clip_ratio,
"value_loss_scale": self.config.value_loss_scale,
"entropy_loss_scale": self.config.entropy_loss_scale,
# Truncation (time limit) must bootstrap from the value function;
# without this the value target is biased at every max_steps cut.
"time_limit_bootstrap": True,
"state_preprocessor": RunningStandardScaler,
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
"value_preprocessor": RunningStandardScaler,
"value_preprocessor_kwargs": {"size": 1, "device": device},
})
if self.config.kl_threshold > 0:
from skrl.resources.schedulers.torch import KLAdaptiveLR
agent_cfg["learning_rate_scheduler"] = KLAdaptiveLR
agent_cfg["learning_rate_scheduler_kwargs"] = {
"kl_threshold": self.config.kl_threshold,
}
# Wire up logging frequency: write_interval is in timesteps.
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
agent_cfg["experiment"]["checkpoint_interval"] = max(
self.config.checkpoint_interval, self.config.rollout_steps
)
self.agent: PPO = PPO(
self.agent = PPO(
models=models,
memory=self.memory,
observation_space=obs_space,
@@ -238,6 +275,4 @@ class Trainer:
trainer.train()
def close(self) -> None:
self.runner.close()
if self.clearml_task:
self.clearml_task.close()
self.runner.close()

7
tests/conftest.py Normal file
View File

@@ -0,0 +1,7 @@
import sys
from pathlib import Path
# Make `src.*` importable regardless of pytest invocation directory.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)

79
tests/test_reward.py Normal file
View File

@@ -0,0 +1,79 @@
"""Reward design tests — balancing must strictly dominate spinning."""
import math
from pathlib import Path
import pytest
import torch
from src.envs.rotary_cartpole import (
RotaryCartPoleConfig,
RotaryCartPoleEnv,
RotaryCartPoleState,
)
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
def _env() -> RotaryCartPoleEnv:
return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
t = lambda v: torch.tensor([float(v)])
return RotaryCartPoleState(
motor_angle=t(motor), motor_vel=t(motor_vel),
pendulum_angle=t(pend), pendulum_vel=t(pend_vel),
)
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
a = torch.tensor([[float(action)]])
pa = torch.tensor([[float(prev_action)]])
return float(env.compute_rewards(state, a, pa)[0])
def test_balancing_beats_spinning_through_upright():
env = _env()
balanced = _state(pend=math.pi, pend_vel=0.0)
spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top
assert _reward(env, balanced) > 2.0 * _reward(env, spinning)
def test_average_spin_cycle_reward_below_balance():
"""Mean reward over a full revolution at high speed << balanced reward."""
env = _env()
angles = torch.linspace(0, 2 * math.pi, 32)
spin_rewards = [
_reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
]
mean_spin = sum(spin_rewards) / len(spin_rewards)
balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
assert balanced > 3.0 * mean_spin
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
env = _env()
over_limit = _state(motor=math.radians(95.0), pend=math.pi)
assert _reward(env, over_limit) == -10.0
assert bool(env.compute_terminations(over_limit)[0])
def test_action_rate_penalty_reduces_reward():
env = _env()
s = _state(pend=math.pi)
smooth = _reward(env, s, action=0.5, prev_action=0.5)
jerky = _reward(env, s, action=0.5, prev_action=-0.5)
assert smooth > jerky
assert smooth - jerky == pytest.approx(
env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6,
)
def test_initial_state_ranges_widen_pendulum_only():
env = _env()
qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()

125
tests/test_robot_config.py Normal file
View File

@@ -0,0 +1,125 @@
"""Robot config loading + motor model unit tests."""
import math
from pathlib import Path
import pytest
import torch
from src.core.robot import ActuatorConfig, load_robot_config
ROBOT_DIR = Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole"
# ── Loading ──────────────────────────────────────────────────────────
def test_canonical_robot_yaml_loads_full_motor_model():
robot = load_robot_config(ROBOT_DIR)
act = robot.actuators[0]
assert act.has_motor_model
# Tuned (unified sysid) values must survive the round-trip.
assert act.filter_tau == pytest.approx(0.096263)
assert act.stribeck_friction_boost == pytest.approx(0.068594)
assert act.stribeck_vel == pytest.approx(5.279594)
assert act.action_bias == pytest.approx(0.056566)
assert act.gear == pytest.approx((0.846499, 1.183733))
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
(tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
(tmp_path / "robot.yaml").write_text(
"urdf: dummy.urdf\n"
"actuators:\n"
" - joint: j\n"
" type: motor\n"
" gear: [1.0, 1.0]\n"
" some_future_field: 42\n"
)
robot = load_robot_config(tmp_path) # must not raise
assert robot.actuators[0].joint == "j"
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
@pytest.fixture
def act() -> ActuatorConfig:
return ActuatorConfig(
joint="m",
gear=(0.8, 1.2),
ctrl_range=(-0.6, 0.6),
deadzone=(0.15, 0.20),
frictionloss=(0.014, 0.001),
damping=(0.013, 0.015),
stribeck_friction_boost=0.07,
stribeck_vel=5.0,
action_bias=0.05,
)
def test_transform_ctrl_clips_to_ctrl_range(act):
# 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
out = act.transform_ctrl(1.0)
assert out == pytest.approx((0.6 + 0.05) * 0.8 / 1.0)
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
# 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
assert act.transform_ctrl(0.05) == 0.0
# -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
assert act.transform_ctrl(-0.15) == 0.0
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
pos = act.transform_ctrl(0.5) # (0.55) * 0.8
neg = act.transform_ctrl(-0.5) # (-0.45) * 1.2
assert pos == pytest.approx(0.55 * 0.8)
assert neg == pytest.approx(-0.45 * 1.2)
def test_transform_action_matches_transform_ctrl_elementwise(act):
vals = torch.linspace(-1.2, 1.2, 49)
batched = act.transform_action(vals.clone())
scalar = torch.tensor([act.transform_ctrl(float(v)) for v in vals])
assert torch.allclose(batched, scalar, atol=1e-6)
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
def test_friction_opposes_motion(act):
assert act.compute_motor_force(vel=2.0, ctrl=0.0) < 0
assert act.compute_motor_force(vel=-2.0, ctrl=0.0) > 0
assert act.compute_motor_force(vel=0.0, ctrl=0.0) == 0.0
def test_stribeck_boost_decays_with_speed(act):
"""Friction torque magnitude (minus damping) is higher near standstill."""
no_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
)
with_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
stribeck_friction_boost=0.07, stribeck_vel=5.0,
)
v_slow, v_fast = 0.1, 50.0
extra_slow = abs(with_strb.compute_motor_force(v_slow, 0.0)) - abs(
no_strb.compute_motor_force(v_slow, 0.0))
extra_fast = abs(with_strb.compute_motor_force(v_fast, 0.0)) - abs(
no_strb.compute_motor_force(v_fast, 0.0))
assert extra_slow == pytest.approx(
0.07 * math.exp(-((v_slow / 5.0) ** 2)), abs=1e-9)
assert extra_fast < extra_slow
assert extra_fast == pytest.approx(0.0, abs=1e-9)
def test_friction_scale_dr_multiplies_friction(act):
base = act.compute_motor_force(1.0, 0.0, friction_scale=1.0, damping_scale=0.0)
# damping_scale=0 isolates the friction term
doubled = act.compute_motor_force(1.0, 0.0, friction_scale=2.0, damping_scale=0.0)
assert doubled == pytest.approx(2.0 * base)

173
tests/test_runner.py Normal file
View File

@@ -0,0 +1,173 @@
"""Runner integration tests — DR, history, action delay, init randomization.
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
is skipped when JAX isn't installed.
"""
import math
from pathlib import Path
import pytest
import torch
from src.core.registry import build_env
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
DR = {
"qpos_noise_std": 0.01,
"qvel_noise_std": 0.5,
"action_delay_steps": [0, 2],
"friction_scale": [0.6, 1.6],
"damping_scale": [0.6, 1.6],
"torque_scale": [0.85, 1.15],
}
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
cfg = MuJoCoRunnerConfig(
num_envs=num_envs,
device="cpu",
history_length=history_length,
domain_rand=domain_rand or {},
)
return MuJoCoRunner(env=env, config=cfg)
# ── Observation layout ───────────────────────────────────────────────
def test_obs_is_raw_plus_history():
runner = _runner(num_envs=2, history_length=3)
raw_dim = runner.env.observation_space.shape[0] # 6
step_dim = raw_dim + 1 # + action
assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim
obs, _ = runner.reset()
assert obs.shape == (2, raw_dim + 3 * step_dim)
# Fresh history must be zero.
assert torch.all(obs[:, raw_dim:] == 0)
actions = torch.full((2, 1), 0.3)
obs, rewards, term, trunc, info = runner.step(actions)
assert obs.shape == (2, raw_dim + 3 * step_dim)
assert rewards.shape == (2, 1)
# Newest history slot holds the commanded action.
assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
runner.close()
def test_no_history_keeps_plain_obs():
runner = _runner(num_envs=2, history_length=0)
assert runner.observation_space.shape[0] == 6
runner.close()
# ── Domain randomization ─────────────────────────────────────────────
def test_dr_scales_sampled_within_ranges_and_resampled():
runner = _runner(num_envs=16, domain_rand=DR)
runner.reset()
for name, (lo, hi) in (
("friction_scale", (0.6, 1.6)),
("damping_scale", (0.6, 1.6)),
("torque_scale", (0.85, 1.15)),
):
vals = runner._dr_scales[name]
assert torch.all(vals >= lo) and torch.all(vals <= hi)
# 16 independent uniform samples are never all identical.
assert vals.std() > 0
before = runner._dr_scales["friction_scale"].clone()
runner.reset()
assert not torch.equal(before, runner._dr_scales["friction_scale"])
delays = runner._dr_delay
assert torch.all(delays >= 0) and torch.all(delays <= 2)
runner.close()
def test_dr_disabled_is_noop():
runner = _runner(num_envs=2, domain_rand={})
runner.reset()
for vals in runner._dr_scales.values():
assert torch.all(vals == 1.0)
assert runner._max_delay == 0
assert runner._qpos_noise_std == 0.0
runner.close()
def test_action_delay_buffer_returns_lagged_action():
runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
runner.reset()
runner._dr_delay = torch.tensor([0, 1, 2])
runner._action_buf.zero_()
a1 = torch.tensor([[1.0], [1.0], [1.0]])
a2 = torch.tensor([[2.0], [2.0], [2.0]])
a3 = torch.tensor([[3.0], [3.0], [3.0]])
d1 = runner._apply_action_delay(a1)
d2 = runner._apply_action_delay(a2)
d3 = runner._apply_action_delay(a3)
assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
runner.close()
# ── Initial-state randomization ──────────────────────────────────────
def test_wide_pendulum_init_actually_applied():
runner = _runner(num_envs=32)
qpos, _ = runner._sim_reset(torch.arange(32))
pend_angles = qpos[:, 1]
# With ±180° init range, samples must spread far beyond the old ±0.05.
assert pend_angles.abs().max() > 1.0
assert pend_angles.std() > 0.5
runner.close()
def test_sim_reset_returns_full_batch():
runner = _runner(num_envs=4)
runner.reset()
qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only
assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
runner.close()
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
def test_mjx_runner_smoke():
pytest.importorskip("jax")
pytest.importorskip("mujoco.mjx")
from src.runners.mjx import MJXRunner, MJXRunnerConfig
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
runner = MJXRunner(
env=env,
config=MJXRunnerConfig(
num_envs=4, device="cpu", history_length=3, domain_rand=DR,
),
)
obs, _ = runner.reset()
raw_dim = env.observation_space.shape[0]
assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))
for _ in range(3):
actions = torch.rand(4, 1) * 2 - 1
obs, rewards, term, trunc, _ = runner.step(actions)
assert torch.isfinite(obs).all()
assert torch.isfinite(rewards).all()
# Wide pendulum init must reach MJX resets too.
qpos, qvel = runner._sim_reset(torch.arange(4))
assert qpos.shape[0] == 4 # full batch
runner.close()

View File

@@ -1,47 +0,0 @@
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
from src.core.env import ActuatorConfig
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return CartPoleConfig(**env_dict)
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
env_config = _build_env_config(cfg)
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
# Build ClearML task name dynamically from Hydra config group choices
if not training_dict.get("clearml_task"):
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "env")
runner_name = choices.get("runner", "runner")
training_name = choices.get("training", "algo")
training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
trainer_config = TrainerConfig(**training_dict)
env = CartPoleEnv(env_config)
runner = MuJoCoRunner(env=env, config=runner_config)
trainer = Trainer(runner=runner, config=trainer_config)
try:
trainer.train()
finally:
trainer.close()
if __name__ == "__main__":
main()