♻️ full agent refactor

This commit is contained in:
2026-06-10 21:15:34 +02:00
parent a98e86ef66
commit 1e0836e1bc
49 changed files with 1309 additions and 829 deletions

4
.gitignore vendored
View File

@@ -6,6 +6,10 @@ outputs/
runs/ runs/
smac3_output/ smac3_output/
training_log.txt training_log.txt
.pytest_cache/
# Real-robot capture data (large .npz recordings)
assets/**/recordings/
# MuJoCo # MuJoCo
MUJOCO_LOG.TXT MUJOCO_LOG.TXT

64
README.md Normal file
View File

@@ -0,0 +1,64 @@
# RL-Framework
A small, fast RL framework for training sim2real policies on a 3D-printed
rotary (Furuta) cartpole — built to scale from a laptop CPU to a GPU worker
(ClearML) without code changes, and to grow into more robots and simulators.
## Architecture
Three orthogonal pieces, composed by Hydra config groups:
| Piece | Role | Implementations |
|---|---|---|
| **Env** (`src/envs/`) | Task logic: obs / reward / termination / init distribution. Pure torch, batched, backend-agnostic. | `rotary_cartpole` |
| **Runner** (`src/runners/`) | Physics + sim2real plumbing (DR, sensor noise, action delay, history buffer). | `mujoco` (CPU), `mjx` (GPU/JAX), `serial` (real ESP32 robot) |
| **Trainer** (`src/training/`) | skrl PPO + shared MLP with optional history encoder. | `ppo`, `ppo_mjx`, `ppo_single`, `ppo_real` |
The robot itself is described once in `assets/<robot>/robot.yaml`
(URDF + identified motor model) and shared by **training, sysid and
deployment** — the motor model (bias → deadzone → gear compensation,
Coulomb + Stribeck friction, viscous damping, first-order lag) is
implemented in `src/core/robot.py` and mirrored exactly in the MJX JIT
step (`src/runners/mjx.py`).
## Train
```bash
# CPU (64 parallel MuJoCo envs)
python scripts/train.py env=rotary_cartpole runner=mujoco training=ppo
# GPU (1024 MJX envs) — local
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx
# GPU — remote on ClearML gpu-queue
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx training.remote=true
```
Videos and scalars stream to ClearML. Checkpoints land in `runs/`.
## Sim2real recipe
1. **Capture** real trajectories: `python -m src.sysid.capture` (writes `.npz` to `assets/<robot>/recordings/`).
2. **Identify** physics: `python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz`
— CMA-ES fits inertials/joint dynamics against the recording (motor model is locked from the unified sysid). Writes `sysid_result.json` + `robot_tuned.yaml` + `*_tuned.urdf`.
3. **Validate** the fit: `python -m src.sysid.visualize`, then copy `robot_tuned.yaml``robot.yaml`.
4. **Train with DR + history**: the runner randomizes friction/damping/torque scales, sensor noise and action latency per episode (`configs/runner/mjx.yaml: domain_rand`), and appends a 10-step (obs, action) history to the observation so the policy can implicitly identify the current dynamics (`history_length`).
5. **Deploy**: `mjpython scripts/eval.py env=rotary_cartpole runner=serial checkpoint=runs/<run>/checkpoints/agent_X.pt`
## Other tools
```bash
mjpython scripts/viz.py env=rotary_cartpole # keyboard-drive the sim
mjpython scripts/viz.py runner=serial # digital twin of the real robot
python scripts/hpo.py env=rotary_cartpole training=ppo_single # ClearML + SMAC3 HPO
pytest tests/ # unit tests
```
## Adding a robot / simulator
- **Robot**: drop `assets/<name>/` (URDF + `robot.yaml`), subclass `BaseEnv`
(obs/reward/termination/`initial_state_ranges`), register in `src/core/registry.py`,
add `configs/env/<name>.yaml`.
- **Simulator**: subclass `BaseRunner` and implement `_sim_initialize`,
`_sim_step`, `_sim_reset` (full-batch return) — DR, history and the
env-side logic come for free. Register in `scripts/train.py: RUNNER_REGISTRY`.

View File

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<robot name="cartpole">
<!-- World link (fixed base) -->
<link name="world"/>
<!-- Cart (slides along x-axis) -->
<link name="cart">
<inertial>
<mass value="1.0"/>
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
</inertial>
<visual>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</visual>
<collision>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</collision>
</link>
<!-- Cart slides along x-axis -->
<joint name="cart_joint" type="prismatic">
<parent link="world"/>
<child link="cart"/>
<axis xyz="1 0 0"/>
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
</joint>
<!-- Pole (rotates around y-axis, attached on top of cart) -->
<link name="pole">
<inertial>
<origin xyz="0 0 0.3"/>
<mass value="0.1"/>
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
</inertial>
<visual>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</visual>
<collision>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</collision>
</link>
<!-- Pole rotates freely (no motor) -->
<joint name="pole_joint" type="revolute">
<parent link="cart"/>
<child link="pole"/>
<origin xyz="0 0 0.05"/>
<axis xyz="0 1 0"/>
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
<dynamics damping="0.0" friction="0.0"/>
</joint>
</robot>

View File

@@ -1,10 +0,0 @@
# Classic cartpole — robot hardware config.
urdf: cartpole.urdf
actuators:
- joint: cart_joint
type: motor
gear: 10.0
ctrl_range: [-1.0, 1.0]
damping: 0.05

View File

@@ -0,0 +1,10 @@
# Motor-only hardware config
# Encoder and motor constants for the motor-only sysid capture.
encoder:
ppr: 11 # pulses per revolution (before quadrature)
gear_ratio: 30.0 # gearbox ratio
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
motor:
max_pwm: 255 # maximum PWM command accepted by firmware

40
assets/motor/motor.xml Normal file
View File

@@ -0,0 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco model="motor_sysid">
<compiler angle="radian" autolimits="true"/>
<option timestep="0.002" integrator="Euler"/>
<worldbody>
<light pos="0 0 1" dir="0 0 -1"/>
<!-- Fixed base (gearbox housing) -->
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921"
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
<!-- Arm: rotates around motor_joint (z-axis) -->
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1"
range="-1.5708 1.5708"
damping="0.001" armature="0.0001" frictionloss="0.03"/>
<!-- Rotor disk: mass is tunable by the optimizer -->
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
contype="0" conaffinity="0"/>
<!-- Arm load: lightweight arm attached to motor shaft -->
<geom name="arm_load" type="capsule" size="0.004"
fromto="0 0 0 -0.014 0.002 0.016"
mass="0.021" rgba="0.8 0.3 0.1 1"
contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint"
gear="0.064"
ctrllimited="true" ctrlrange="-1 1"
dyntype="filter" dynprm="0.03"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Motor-only model WITHOUT arm/pendulum load.
Use for arm-off sysid captures. arm_load mass set to ~0. -->
<mujoco model="motor_sysid_bare">
<compiler angle="radian" autolimits="true"/>
<option timestep="0.002" integrator="Euler"/>
<worldbody>
<light pos="0 0 1" dir="0 0 -1"/>
<!-- Fixed base (gearbox housing) -->
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921"
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
<!-- Arm: rotates around motor_joint (z-axis) -->
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1"
range="-1.5708 1.5708"
damping="0.001" armature="0.0001" frictionloss="0.03"/>
<!-- Rotor disk: mass is tunable by the optimizer -->
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
contype="0" conaffinity="0"/>
<!-- Arm load: near-zero mass for bare-shaft sysid -->
<geom name="arm_load" type="capsule" size="0.004"
fromto="0 0 0 -0.014 0.002 0.016"
mass="0.001" rgba="0.8 0.3 0.1 0.3"
contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint"
gear="0.064"
ctrllimited="true" ctrlrange="-1 1"
dyntype="filter" dynprm="0.03"/>
</actuator>
</mujoco>

Binary file not shown.

After

Width:  |  Height:  |  Size: 463 KiB

View File

@@ -0,0 +1,67 @@
{
"best_params": {
"actuator_gear_pos": 0.3711939014035462,
"actuator_gear_neg": 0.42814281188601877,
"actuator_filter_tau": 0.022300731564787457,
"motor_damping_pos": 0.0013836218905629106,
"motor_damping_neg": 0.005196351489379768,
"motor_armature": 0.0027534181478216656,
"motor_frictionloss_pos": 0.03674406439012955,
"motor_frictionloss_neg": 0.06908200024786905,
"viscous_quadratic": 0.000958226218765762,
"back_emf_gain": 0.0036492272912788297,
"stribeck_friction_boost": 0.044748043677129666,
"stribeck_vel": 4.0513395945623705,
"rotor_mass": 0.03982640764507874,
"motor_deadzone_pos": 0.14181963932762467,
"motor_deadzone_neg": 0.031454276545010214,
"action_bias": -0.007362969452509152,
"gearbox_backlash": 1.4749880999407965e-09
},
"best_cost": 0.21167928018839952,
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/motor/recordings/motor_from_cartpole_162432.npz",
"param_names": [
"actuator_gear_pos",
"actuator_gear_neg",
"actuator_filter_tau",
"motor_damping_pos",
"motor_damping_neg",
"motor_armature",
"motor_frictionloss_pos",
"motor_frictionloss_neg",
"viscous_quadratic",
"back_emf_gain",
"stribeck_friction_boost",
"stribeck_vel",
"rotor_mass",
"motor_deadzone_pos",
"motor_deadzone_neg",
"action_bias",
"gearbox_backlash"
],
"defaults": {
"actuator_gear_pos": 0.064,
"actuator_gear_neg": 0.064,
"actuator_filter_tau": 0.03,
"motor_damping_pos": 0.003,
"motor_damping_neg": 0.003,
"motor_armature": 0.0001,
"motor_frictionloss_pos": 0.03,
"motor_frictionloss_neg": 0.03,
"viscous_quadratic": 0.0,
"back_emf_gain": 0.0,
"stribeck_friction_boost": 0.0,
"stribeck_vel": 2.0,
"rotor_mass": 0.012,
"motor_deadzone_pos": 0.08,
"motor_deadzone_neg": 0.08,
"action_bias": 0.0,
"gearbox_backlash": 0.0
},
"timestamp": "2026-03-23T20:50:53.648753",
"history_summary": {
"first_cost": 5.105499579419285,
"final_cost": 0.21167928018839952,
"generations": 500
}
}

View File

@@ -0,0 +1,19 @@
<?xml version='1.0' encoding='utf-8'?>
<mujoco model="motor_sysid">
<compiler angle="radian" autolimits="true" />
<option timestep="0.002" integrator="Euler" />
<worldbody>
<light pos="0 0 1" dir="0 0 -1" />
<body name="base" pos="0 0 0.15">
<geom type="cylinder" size="0.04 0.06" mass="0.921" rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0" />
<body name="arm" pos="0 0 0">
<joint name="motor_joint" type="hinge" axis="0 0 1" range="-1.5708 1.5708" damping="0" armature="0.0027534181478216656" frictionloss="0" />
<geom name="rotor_disk" type="cylinder" size="0.008 0.004" mass="0.03982640764507874" pos="0 0 0" rgba="0.6 0.6 0.6 1" contype="0" conaffinity="0" />
<geom name="arm_load" type="capsule" size="0.004" fromto="0 0 0 -0.014 0.002 0.016" mass="0.021" rgba="0.8 0.3 0.1 1" contype="0" conaffinity="0" />
</body>
</body>
</worldbody>
<actuator>
<general name="motor" joint="motor_joint" gear="0.3996683566447825" ctrllimited="true" ctrlrange="-1 1" dyntype="filter" dynprm="0.022300731564787457" />
</actuator>
</mujoco>

6
assets/motor/robot.yaml Normal file
View File

@@ -0,0 +1,6 @@
# Motor-only sysid config
# Minimal config for the motor-only identification pipeline.
# The optimizer patches motor.xml in-memory; this file tells it
# which MJCF to load and provides encoder/hardware constants.
mjcf: motor.xml

View File

@@ -0,0 +1,4 @@
# Motor-only sysid config — bare shaft (no arm/pendulum load)
# Use with arm-off captures to identify pure motor dynamics.
mjcf: motor_bare.xml

View File

@@ -0,0 +1,23 @@
# Tuned motor config — generated by src.sysid.motor.optimize
# Original: robot.yaml
mjcf: motor_tuned.xml
joints:
motor_joint:
armature: 0.002753
frictionloss: 0.052913
hardware_realism:
actuator_gear_pos: 0.371194
actuator_gear_neg: 0.428143
motor_damping_pos: 0.001384
motor_damping_neg: 0.005196
motor_frictionloss_pos: 0.036744
motor_frictionloss_neg: 0.069082
motor_deadzone_pos: 0.14182
motor_deadzone_neg: 0.031454
action_bias: -0.007363
viscous_quadratic: 0.000958
back_emf_gain: 0.003649
stribeck_friction_boost: 0.044748
stribeck_vel: 4.05134
gearbox_backlash: 0.0

View File

@@ -1,21 +1,30 @@
# Tuned robot config — generated by src.sysid.optimize # Canonical training model — unified sysid (cost 0.925, 475 generations).
# Source: sysid_result.json → exported via src.sysid.export.
# Key physics: ~96 ms motor lag (filter_tau), Stribeck friction, driver bias.
# Regenerate with:
# python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz
# then copy robot_tuned.yaml over this file once validated
# (python -m src.sysid.visualize to compare real vs sim).
urdf: rotary_cartpole_tuned.urdf
urdf: rotary_cartpole.urdf
actuators: actuators:
- joint: motor_joint - joint: motor_joint
type: motor type: motor
gear: [0.424182, 0.425031] # torque constant [pos, neg] (motor sysid) gear: [0.846499, 1.183733] # torque constant [pos, neg]
ctrl_range: [-0.592, 0.592] # effective control bound (sysid-tuned) ctrl_range: [-0.686251, 0.686251] # PWM saturation (MAX_MOTOR_SPEED / 255)
deadzone: [0.141291, 0.078015] # L298N min |ctrl| for torque [pos, neg] deadzone: [0.181097, 0.202072] # L298N min |ctrl| for torque [pos, neg]
damping: [0.002027, 0.014665] # viscous damping [pos, neg] damping: [0.013165, 0.015452] # viscous damping [pos, neg]
frictionloss: [0.057328, 0.053355] # Coulomb friction [pos, neg] frictionloss: [0.014244, 0.001005] # Coulomb friction [pos, neg]
filter_tau: 0.005035 # 1st-order actuator filter (motor sysid) filter_tau: 0.096263 # 1st-order actuator lag (s) — dominant!
viscous_quadratic: 0.000285 # velocity² drag stribeck_friction_boost: 0.068594 # extra static friction near standstill
back_emf_gain: 0.006758 # back-EMF torque reduction stribeck_vel: 5.279594 # Stribeck decay velocity (rad/s)
action_bias: 0.056566 # additive ctrl bias (driver asymmetry)
joints: joints:
motor_joint: motor_joint:
armature: 0.002773 # reflected rotor inertia (motor sysid) armature: 0.001676 # reflected rotor inertia (kg·m²)
frictionloss: 0.0 # handled by motor model via qfrc_applied frictionloss: 0.0 # handled by motor model via qfrc_applied
pendulum_joint: pendulum_joint:
damping: 0.000119 damping: 1.2e-05
frictionloss: 1.0e-05 frictionloss: 7.2e-05

View File

@@ -0,0 +1,34 @@
# Tuned robot config — generated by src.sysid.optimize
# Original: robot.yaml
# Run `python -m src.sysid.visualize` to compare real vs sim.
urdf: rotary_cartpole_tuned.urdf
actuators:
- joint: motor_joint
type: motor
gear:
- 0.846499
- 1.183733
ctrl_range:
- -0.686251
- 0.686251
deadzone:
- 0.181097
- 0.202072
damping:
- 0.013165
- 0.015452
frictionloss:
- 0.014244
- 0.001005
filter_tau: 0.096263
stribeck_friction_boost: 0.068594
stribeck_vel: 5.279594
action_bias: 0.056566
joints:
motor_joint:
armature: 0.001676
frictionloss: 0.0
pendulum_joint:
damping: 1.2e-05
frictionloss: 7.2e-05

View File

@@ -0,0 +1,80 @@
<?xml version='1.0' encoding='utf-8'?>
<robot name="rotary_cartpole">
<link name="world" />
<link name="base_link">
<inertial>
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
<mass value="0.921" />
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
</inertial>
<visual>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0 0 0" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="base_joint" type="fixed">
<parent link="world" />
<child link="base_link" />
</joint>
<link name="arm">
<inertial>
<origin xyz="0.02679980831009001 -0.015110803875962989 -0.005337417994989926" rpy="0 0 0" />
<mass value="0.04052645463607292" />
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="motor_joint" type="revolute">
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
<parent link="base_link" />
<child link="arm" />
<axis xyz="0 0 1" />
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
<dynamics damping="0.001" />
</joint>
<link name="pendulum">
<inertial>
<origin xyz="0.04237886229290564 -0.05762212306183831 -0.0006039324398328591" rpy="0 0 0" />
<mass value="0.056793277126969105" />
<inertia ixx="0.0005956997356690322" iyy="8.986080748758447e-05" izz="0.0006855370063143978" ixy="-1.074983574165622e-05" iyz="0.0" ixz="0.0" />
</inertial>
<visual>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</visual>
<collision>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
</geometry>
</collision>
</link>
<joint name="pendulum_joint" type="continuous">
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
<parent link="arm" />
<child link="pendulum" />
<axis xyz="0 -1 0" />
<dynamics damping="0.0001" />
</joint>
</robot>

Binary file not shown.

After

Width:  |  Height:  |  Size: 845 KiB

View File

@@ -0,0 +1,101 @@
{
"best_params": {
"actuator_gear_pos": 0.8464986357922419,
"actuator_gear_neg": 1.1837332515116656,
"actuator_filter_tau": 0.09626340711982824,
"motor_damping_pos": 0.013165358841570964,
"motor_damping_neg": 0.015451769431035585,
"motor_armature": 0.0016762295220909746,
"motor_frictionloss_pos": 0.014244285616762612,
"motor_frictionloss_neg": 0.0010053574956085138,
"stribeck_friction_boost": 0.06859381652654695,
"stribeck_vel": 5.279593819777324,
"motor_deadzone_pos": 0.1810971675049997,
"motor_deadzone_neg": 0.20207167371643614,
"action_bias": 0.05656632270757292,
"arm_mass": 0.04052645463607292,
"arm_com_x": 0.02679980831009001,
"arm_com_y": -0.015110803875962989,
"arm_com_z": -0.005337417994989926,
"pendulum_mass": 0.056793277126969105,
"pendulum_com_x": 0.04237886229290564,
"pendulum_com_y": -0.05762212306183831,
"pendulum_com_z": -0.0006039324398328591,
"pendulum_ixx": 0.0005956997356690322,
"pendulum_iyy": 8.986080748758447e-05,
"pendulum_izz": 0.0006855370063143978,
"pendulum_ixy": -1.074983574165622e-05,
"pendulum_damping": 1.1718327130851149e-05,
"pendulum_frictionloss": 7.204244487092445e-05,
"ctrl_limit": 0.6862506602999546
},
"best_cost": 0.9250391788859942,
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260328_153749.npz",
"param_names": [
"actuator_gear_pos",
"actuator_gear_neg",
"actuator_filter_tau",
"motor_damping_pos",
"motor_damping_neg",
"motor_armature",
"motor_frictionloss_pos",
"motor_frictionloss_neg",
"stribeck_friction_boost",
"stribeck_vel",
"motor_deadzone_pos",
"motor_deadzone_neg",
"action_bias",
"arm_mass",
"arm_com_x",
"arm_com_y",
"arm_com_z",
"pendulum_mass",
"pendulum_com_x",
"pendulum_com_y",
"pendulum_com_z",
"pendulum_ixx",
"pendulum_iyy",
"pendulum_izz",
"pendulum_ixy",
"pendulum_damping",
"pendulum_frictionloss",
"ctrl_limit"
],
"defaults": {
"actuator_gear_pos": 0.371194,
"actuator_gear_neg": 0.428143,
"actuator_filter_tau": 0.022301,
"motor_damping_pos": 0.001384,
"motor_damping_neg": 0.005196,
"motor_armature": 0.002753,
"motor_frictionloss_pos": 0.036744,
"motor_frictionloss_neg": 0.069082,
"stribeck_friction_boost": 0.0,
"stribeck_vel": 2.0,
"motor_deadzone_pos": 0.14182,
"motor_deadzone_neg": 0.031454,
"action_bias": 0.0,
"arm_mass": 0.0211,
"arm_com_x": -0.0071,
"arm_com_y": 0.00085,
"arm_com_z": 0.00795,
"pendulum_mass": 0.03937,
"pendulum_com_x": 0.06025,
"pendulum_com_y": -0.07602,
"pendulum_com_z": -0.00346,
"pendulum_ixx": 6.2e-05,
"pendulum_iyy": 3.7e-05,
"pendulum_izz": 7.83e-05,
"pendulum_ixy": -6.93e-06,
"pendulum_damping": 0.0001,
"pendulum_frictionloss": 0.0001,
"ctrl_limit": 0.588
},
"preprocess_vel": true,
"timestamp": "2026-03-28T17:09:39.241413",
"history_summary": {
"first_cost": 13.490216059926947,
"final_cost": 0.9250391788859942,
"generations": 475
}
}

View File

@@ -1,5 +1,5 @@
defaults: defaults:
- env: cartpole - env: rotary_cartpole
- runner: mujoco - runner: mujoco
- training: ppo - training: ppo
- _self_ - _self_

View File

@@ -1,7 +0,0 @@
max_steps: 500
robot_path: assets/cartpole
angle_threshold: 0.418
cart_limit: 2.4
reward_alive: 1.0
reward_pole_upright_scale: 1.0
reward_action_penalty_scale: 0.01

View File

@@ -9,6 +9,7 @@ balance_vel_scale: 0.5 # how fast the balance bonus decays with pendul
motor_vel_penalty: 0.01 # penalise high motor angular velocity motor_vel_penalty: 0.01 # penalise high motor angular velocity
motor_angle_penalty: 0.05 # penalise deviation from centre motor_angle_penalty: 0.05 # penalise deviation from centre
action_penalty: 0.05 # penalise large actions (energy cost) action_penalty: 0.05 # penalise large actions (energy cost)
action_rate_penalty: 0.01 # penalise action changes (real-motor smoothness)
# ── Initial state randomisation ────────────────────────────────────── # ── Initial state randomisation ──────────────────────────────────────
pendulum_init_range_deg: 180.0 # pendulum starts in [-180°, +180°] pendulum_init_range_deg: 180.0 # pendulum starts in [-180°, +180°]
@@ -22,5 +23,6 @@ hpo:
motor_vel_penalty: {min: 0.001, max: 0.1} motor_vel_penalty: {min: 0.001, max: 0.1}
motor_angle_penalty: {min: 0.01, max: 0.2} motor_angle_penalty: {min: 0.01, max: 0.2}
action_penalty: {min: 0.01, max: 0.2} action_penalty: {min: 0.01, max: 0.2}
action_rate_penalty: {min: 0.001, max: 0.1}
pendulum_init_range_deg: {min: 30.0, max: 180.0} pendulum_init_range_deg: {min: 30.0, max: 180.0}
max_steps: {values: [500, 1000, 2000]} max_steps: {values: [500, 1000, 2000]}

View File

@@ -2,9 +2,7 @@ num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 10 substeps: 10
history_length: 10 # RMA-style: 10-step window of (obs, action) pairs history_length: 10 # (obs, action) window for implicit adaptation
rma_mode: "none" # "none" | "teacher" | "deploy"
# ── Domain randomization (sim-to-real) ────────────────────────────── # ── Domain randomization (sim-to-real) ──────────────────────────────
# Full DR on GPU: latency + sensor noise + per-env dynamics scales # Full DR on GPU: latency + sensor noise + per-env dynamics scales

View File

@@ -2,9 +2,7 @@ num_envs: 64
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 10 substeps: 10
history_length: 10 # must match training.history_length (DR + embedding) history_length: 10 # (obs, action) window for implicit adaptation
rma_mode: "none" # "none" | "teacher" | "deploy"
# ── Domain randomization (sim-to-real) ────────────────────────────── # ── Domain randomization (sim-to-real) ──────────────────────────────
# Noise/delay levels anchored to the real recordings (~50 Hz, ~0.5 rad/s # Noise/delay levels anchored to the real recordings (~50 Hz, ~0.5 rad/s

View File

@@ -7,8 +7,6 @@ dt: 0.002
substeps: 10 substeps: 10
history_length: 10 history_length: 10
rma_mode: "none" # "none" | "teacher" | "deploy"
# Clean by default (deterministic eval). Confirming-experiment example — # Clean by default (deterministic eval). Confirming-experiment example —
# re-eval an existing checkpoint in sim with a fixed 1-step action delay: # re-eval an existing checkpoint in sim with a fixed 1-step action delay:
# mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \ # mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \

View File

@@ -9,5 +9,3 @@ baud: 115200
dt: 0.02 # control loop period (50 Hz, matches training) dt: 0.02 # control loop period (50 Hz, matches training)
no_data_timeout: 2.0 # seconds of silence before declaring disconnect no_data_timeout: 2.0 # seconds of silence before declaring disconnect
history_length: 10 # must match training runner history_length: 10 # must match training runner
rma_mode: "none" # "none" | "teacher" | "deploy"

View File

@@ -1,14 +1,18 @@
# PPO defaults — sized for the CPU MuJoCo runner (64 parallel envs).
# 128 rollout steps × 64 envs ≈ 8K samples per update.
hidden_sizes: [256, 256] hidden_sizes: [256, 256]
total_timesteps: 5000000 total_timesteps: 500000 # × 64 envs = 32M env steps
rollout_steps: 2048 rollout_steps: 128
learning_epochs: 10 learning_epochs: 5
mini_batches: 8 mini_batches: 4
discount_factor: 0.99 discount_factor: 0.99
gae_lambda: 0.95 gae_lambda: 0.95
learning_rate: 0.0003 learning_rate: 0.0003
clip_ratio: 0.2 clip_ratio: 0.2
value_loss_scale: 0.5 value_loss_scale: 0.5
entropy_loss_scale: 0.01 entropy_loss_scale: 0.01
kl_threshold: 0.01 # KL-adaptive LR; 0 = fixed learning rate
log_interval: 1000 log_interval: 1000
checkpoint_interval: 50000 checkpoint_interval: 50000
@@ -18,13 +22,9 @@ max_log_std: 2.0
record_video_every: 10000 record_video_every: 10000
# RMA-style history encoder # History encoder output dim — the window size itself comes from
history_length: 10 # temporal window (must match runner) # runner.history_length (single source of truth).
embedding_dim: 32 # history encoder output dimension embedding_dim: 32
# RMA (Rapid Motor Adaptation)
rma_mode: "none" # "none" | "teacher" | "deploy"
latent_dim: 8 # env encoder / adaptation latent dimension
# ClearML remote execution (GPU worker) # ClearML remote execution (GPU worker)
remote: false remote: false

View File

@@ -1,15 +1,20 @@
# PPO tuned for MJX (1024+ parallel envs on GPU). # PPO sized for MJX (1024+ parallel envs on GPU).
# Inherits defaults + HPO ranges from ppo.yaml. # Inherits defaults + HPO ranges from ppo.yaml.
# With 1024 envs, each timestep collects 1024 samples, so total_timesteps #
# can be much lower than the CPU config. # Short rollouts × many envs is the GPU-PPO sweet spot:
# 24 steps × 1024 envs ≈ 25K samples per update (~6K per mini-batch).
# (The old rollout_steps=2048 inherited from the CPU config meant a
# 2M-sample memory per update — GBs of VRAM and glacial updates.)
defaults: defaults:
- ppo - ppo
- _self_ - _self_
total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps rollout_steps: 24
mini_batches: 32 # keep mini-batch size similar (~32K) mini_batches: 4
learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling) learning_epochs: 5
learning_rate: 0.0003 # KL-adaptive scheduler handles the rest
total_timesteps: 100000 # × 1024 envs ≈ 100M env steps
log_interval: 100 log_interval: 100
checkpoint_interval: 10000 checkpoint_interval: 10000

View File

@@ -75,14 +75,13 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None: def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
"""Return the history/adaptation encoder output dim, if present. """Return the history encoder output dim, if present.
Lets eval reconstruct an embedding policy without knowing the training Lets eval reconstruct an embedding policy without knowing the training
embedding_dim/latent_dim — read it straight from the saved weights. embedding_dim — read it straight from the saved weights.
""" """
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"): if "history_encoder.fc.weight" in state_dict:
if key in state_dict: return state_dict["history_encoder.fc.weight"].shape[0]
return state_dict[key].shape[0]
return None return None
@@ -92,14 +91,13 @@ def load_policy(
action_space: spaces.Space, action_space: spaces.Space,
device: torch.device = torch.device("cpu"), device: torch.device = torch.device("cpu"),
history_length: int = 0, history_length: int = 0,
rma_mode: str = "none",
raw_obs_dim: int = 0, raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]: ) -> tuple[SharedMLP, RunningStandardScaler]:
"""Load a trained SharedMLP + observation normalizer from a checkpoint. """Load a trained SharedMLP + observation normalizer from a checkpoint.
For DR + history-embedding policies (history_length > 0) or RMA deploy For DR + history-embedding policies (history_length > 0), the history
policies (rma_mode="deploy"), the history/adaptation encoder must be encoder is reconstructed too — its output dim is read back from the
reconstructed too — its output dim is read back from the saved weights. saved weights.
Returns: Returns:
(model, state_preprocessor) ready for inference. (model, state_preprocessor) ready for inference.
@@ -117,11 +115,9 @@ def load_policy(
action_space=action_space, action_space=action_space,
device=device, device=device,
hidden_sizes=hidden_sizes, hidden_sizes=hidden_sizes,
history_length=history_length, history_length=history_length if enc_out else 0,
rma_mode=rma_mode,
raw_obs_dim=raw_obs_dim, raw_obs_dim=raw_obs_dim,
embedding_dim=enc_out or 32, # legacy "none" + history embedding_dim=enc_out or 32,
latent_dim=enc_out or 8, # RMA deploy adaptation module
) )
model.load_state_dict(ckpt["policy"]) model.load_state_dict(ckpt["policy"])
model.eval() model.eval()
@@ -189,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
@hydra.main(version_base=None, config_path="../configs", config_name="config") @hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None: def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco_single") runner_name = choices.get("runner", "mujoco_single")
checkpoint_path = cfg.get("checkpoint", None) checkpoint_path = cfg.get("checkpoint", None)
@@ -222,7 +218,6 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy( model, preprocessor = load_policy(
checkpoint_path, runner.observation_space, runner.action_space, device, checkpoint_path, runner.observation_space, runner.action_space, device,
history_length=runner.config.history_length, history_length=runner.config.history_length,
rma_mode=runner.config.rma_mode,
raw_obs_dim=runner.env.observation_space.shape[0], raw_obs_dim=runner.env.observation_space.shape[0],
) )
@@ -311,7 +306,6 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy( model, preprocessor = load_policy(
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device, checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
history_length=serial_runner.config.history_length, history_length=serial_runner.config.history_length,
rma_mode=serial_runner.config.rma_mode,
raw_obs_dim=serial_runner.env.observation_space.shape[0], raw_obs_dim=serial_runner.env.observation_space.shape[0],
) )
@@ -339,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
if _reset_flag[0]: if _reset_flag[0]:
_reset_flag[0] = False _reset_flag[0] = False
serial_runner._send("M0") serial_runner._send("M0")
serial_runner._drive_to_center() obs, _ = serial_runner.reset() # drives to center + settles
serial_runner._wait_for_pendulum_still()
obs, _ = serial_runner.reset()
step = 0 step = 0
episode += 1 episode += 1
episode_reward = 0.0 episode_reward = 0.0
@@ -376,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
"step", n=step, reward=round(reward.item(), 3), "step", n=step, reward=round(reward.item(), 3),
action=round(action[0, 0].item(), 2), action=round(action[0, 0].item(), 2),
ep_reward=round(episode_reward, 1), ep_reward=round(episode_reward, 1),
motor_enc=state["encoder_count"], motor_deg=round(math.degrees(state["motor_rad"]), 1),
pend_deg=round(state["pendulum_angle"], 1), pend_deg=round(math.degrees(state["pend_rad"]), 1),
) )
# Check for safety / disconnection. # Check for safety / disconnection.

View File

@@ -352,7 +352,7 @@ def main() -> None:
reuse_last_task_id=False, reuse_last_task_id=False,
) )
task.set_base_docker( task.set_base_docker(
docker_image="registry.kube.optimize/worker-image:latest", docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
docker_arguments=[ docker_arguments=[
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1", "-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1", "-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",

View File

@@ -63,7 +63,7 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"""Initialize ClearML task with project structure and tags.""" """Initialize ClearML task with project structure and tags."""
Task.ignore_requirements("torch") Task.ignore_requirements("torch")
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco") runner_name = choices.get("runner", "mujoco")
training_name = choices.get("training", "ppo") training_name = choices.get("training", "ppo")
@@ -113,7 +113,7 @@ def main(cfg: DictConfig) -> None:
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)} _valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys} training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "rotary_cartpole")
env = build_env(env_name, cfg) env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg) runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict) trainer_config = TrainerConfig(**training_dict)

View File

@@ -1,296 +0,0 @@
"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.
Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
and trains a HistoryEncoder (adaptation module) to predict the teacher's
latent z from observation-action history using supervised MSE.
Usage:
python scripts/train_adaptation.py \
--checkpoint runs/<run>/checkpoints/agent_XXXXX.pt \
--env rotary_cartpole \
--robot-path assets/rotary_cartpole \
--num-envs 64 \
--iterations 2000 \
--lr 3e-4
"""
import argparse
import pathlib
import sys
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import structlog
import torch
import tqdm
from gymnasium import spaces
from omegaconf import OmegaConf
from src.core.registry import build_env
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
log = structlog.get_logger()
def _load_teacher_checkpoint(
path: str, obs_space: spaces.Space, act_space: spaces.Space,
device: torch.device, raw_obs_dim: int, mu_dim: int,
hidden_sizes: tuple[int, ...], latent_dim: int,
) -> SharedMLP:
"""Reconstruct the teacher SharedMLP and load saved weights."""
model = SharedMLP(
observation_space=obs_space,
action_space=act_space,
device=device,
hidden_sizes=hidden_sizes,
rma_mode="teacher",
raw_obs_dim=raw_obs_dim,
mu_dim=mu_dim,
latent_dim=latent_dim,
)
ckpt = torch.load(path, map_location=device, weights_only=True)
# skrl saves under "policy" key with state_dict.
if "policy" in ckpt:
model.load_state_dict(ckpt["policy"])
else:
model.load_state_dict(ckpt)
return model
def _build_deploy_model(
teacher: SharedMLP,
obs_space: spaces.Space,
act_space: spaces.Space,
device: torch.device,
raw_obs_dim: int,
history_length: int,
hidden_sizes: tuple[int, ...],
latent_dim: int,
) -> SharedMLP:
"""Create a deploy-mode SharedMLP and copy backbone + heads from teacher."""
model = SharedMLP(
observation_space=obs_space,
action_space=act_space,
device=device,
hidden_sizes=hidden_sizes,
rma_mode="deploy",
raw_obs_dim=raw_obs_dim,
history_length=history_length,
latent_dim=latent_dim,
)
# Copy backbone, policy head, value head from teacher.
model.net.load_state_dict(teacher.net.state_dict())
model.mean_layer.load_state_dict(teacher.mean_layer.state_dict())
model.value_layer.load_state_dict(teacher.value_layer.state_dict())
model.log_std_parameter.data.copy_(teacher.log_std_parameter.data)
return model
def main() -> None:
parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
parser.add_argument("--env", default="rotary_cartpole")
parser.add_argument("--robot-path", default="assets/rotary_cartpole")
parser.add_argument("--num-envs", type=int, default=64)
parser.add_argument("--iterations", type=int, default=2000)
parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--latent-dim", type=int, default=8)
parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
parser.add_argument("--history-length", type=int, default=10)
parser.add_argument("--output", default="checkpoints/adaptation.pt")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
device = torch.device(args.device)
hidden_sizes = tuple(args.hidden_sizes)
# ── Build env + runner (deploy mode with history + DR) ───────
env_cfg = OmegaConf.create({"env": {
"robot_path": args.robot_path,
}})
env = build_env(args.env, env_cfg)
runner_cfg = MuJoCoRunnerConfig(
num_envs=args.num_envs,
device=args.device,
history_length=args.history_length,
rma_mode="deploy",
domain_rand={
"qpos_noise_std": 0.01,
"qvel_noise_std": 0.5,
"action_delay_steps": [0, 2],
"friction_scale": [0.6, 1.6],
"damping_scale": [0.6, 1.6],
"torque_scale": [0.85, 1.15],
},
)
runner = MuJoCoRunner(env=env, config=runner_cfg)
raw_obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
mu_dim = runner.privileged_dim
log.info(
"adaptation_setup",
raw_obs_dim=raw_obs_dim,
act_dim=act_dim,
mu_dim=mu_dim,
latent_dim=args.latent_dim,
history_length=args.history_length,
)
# ── Load teacher & build deploy model ────────────────────────
# Teacher obs space: [raw_obs, μ]
teacher_obs_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
)
teacher = _load_teacher_checkpoint(
path=args.checkpoint,
obs_space=teacher_obs_space,
act_space=env.action_space,
device=device,
raw_obs_dim=raw_obs_dim,
mu_dim=mu_dim,
hidden_sizes=hidden_sizes,
latent_dim=args.latent_dim,
)
teacher.eval()
for p in teacher.parameters():
p.requires_grad_(False)
# Deploy obs space: [raw_obs, history_flat]
step_dim = raw_obs_dim + act_dim
deploy_obs_space = spaces.Box(
low=-torch.inf, high=torch.inf,
shape=(raw_obs_dim + args.history_length * step_dim,),
)
deploy_model = _build_deploy_model(
teacher=teacher,
obs_space=deploy_obs_space,
act_space=env.action_space,
device=device,
raw_obs_dim=raw_obs_dim,
history_length=args.history_length,
hidden_sizes=hidden_sizes,
latent_dim=args.latent_dim,
)
# Freeze everything except the adaptation module.
for name, param in deploy_model.named_parameters():
if "adaptation_module" not in name:
param.requires_grad_(False)
optimizer = torch.optim.Adam(
deploy_model.adaptation_module.parameters(), lr=args.lr,
)
# ── Training loop ────────────────────────────────────────────
log.info("starting_adaptation_training", iterations=args.iterations)
obs, _ = runner.reset()
for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
# Collect a rollout using the deploy model.
z_targets: list[torch.Tensor] = []
z_preds: list[torch.Tensor] = []
for _step in range(args.rollout_steps):
with torch.no_grad():
# Get action from deploy model (uses adaptation module).
aug_obs = obs # already augmented by runner
actions = deploy_model.act(
{"states": aug_obs}, role="policy",
)[0]
obs, _, _, _, info = runner.step(actions)
# Compute teacher's z from privileged μ.
mu = info.get("privileged_obs")
if mu is not None:
z_target = teacher.env_encoder(mu)
z_targets.append(z_target)
# Compute adaptation module's ẑ from history.
raw = aug_obs[:, :raw_obs_dim]
hist_flat = aug_obs[:, raw_obs_dim:]
history = hist_flat.reshape(
-1, args.history_length, step_dim,
)
z_pred = deploy_model.adaptation_module(history)
z_preds.append(z_pred)
if not z_targets:
continue
# Supervised update on adaptation module.
z_target_batch = torch.cat(z_targets, dim=0).detach()
z_pred_batch = torch.cat(z_preds, dim=0)
# Re-compute z_pred with gradients (the ones above were no_grad).
# We need to re-encode from stored data; instead, collect with grad:
# Actually, z_preds were computed in no_grad. Let me re-collect
# a fresh batch with gradients.
obs_reset, _ = runner.reset()
obs = obs_reset
z_targets_grad: list[torch.Tensor] = []
z_preds_grad: list[torch.Tensor] = []
for _step in range(args.rollout_steps):
with torch.no_grad():
aug_obs = obs
actions = deploy_model.act(
{"states": aug_obs}, role="policy",
)[0]
obs, _, _, _, info = runner.step(actions)
mu = info.get("privileged_obs")
if mu is not None:
with torch.no_grad():
z_target = teacher.env_encoder(mu)
z_targets_grad.append(z_target)
# This time, compute z_pred WITH gradients.
raw = aug_obs[:, :raw_obs_dim]
hist_flat = aug_obs[:, raw_obs_dim:]
history = hist_flat.reshape(
-1, args.history_length, step_dim,
)
z_pred = deploy_model.adaptation_module(history)
z_preds_grad.append(z_pred)
if not z_targets_grad:
continue
z_target_all = torch.cat(z_targets_grad, dim=0).detach()
z_pred_all = torch.cat(z_preds_grad, dim=0)
loss = torch.nn.functional.mse_loss(z_pred_all, z_target_all)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if iteration % 50 == 0:
log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")
# ── Save adaptation weights ──────────────────────────────────
out_path = pathlib.Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
# Save the full deploy model state dict.
torch.save(deploy_model.state_dict(), out_path)
log.info("adaptation_saved", path=str(out_path))
runner.close()
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@
Usage (simulation): Usage (simulation):
mjpython scripts/viz.py env=rotary_cartpole mjpython scripts/viz.py env=rotary_cartpole
mjpython scripts/viz.py env=cartpole +com=true mjpython scripts/viz.py env=rotary_cartpole +com=true
Usage (real hardware — digital twin): Usage (real hardware — digital twin):
mjpython scripts/viz.py env=rotary_cartpole runner=serial mjpython scripts/viz.py env=rotary_cartpole runner=serial
@@ -104,7 +104,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
@hydra.main(version_base=None, config_path="../configs", config_name="config") @hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None: def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco") runner_name = choices.get("runner", "mujoco")
if runner_name == "serial": if runner_name == "serial":
@@ -229,11 +229,12 @@ def _main_serial(cfg: DictConfig, env_name: str) -> None:
_reset_flag[0] = False _reset_flag[0] = False
serial_runner._send("M0") serial_runner._send("M0")
serial_runner._drive_to_center() serial_runner._drive_to_center()
serial_runner._wait_for_pendulum_still() serial_runner._wait_for_settle()
logger.info("reset (drive-to-center + settle)") logger.info("reset (drive-to-center + settle)")
# Send motor command to real hardware. # Send motor command to real hardware (same PWM scaling as
motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255) # the policy path: ctrl_range-limited).
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
serial_runner._send(f"M{motor_speed}") serial_runner._send(f"M{motor_speed}")
# Sync MuJoCo model with real sensor data. # Sync MuJoCo model with real sensor data.

View File

@@ -1,8 +1,10 @@
import abc import abc
import dataclasses import dataclasses
from typing import TypeVar, Generic, Any from typing import TypeVar, Generic, Any
from gymnasium import spaces
import numpy as np
import torch import torch
from gymnasium import spaces
from src.core.robot import RobotConfig, load_robot_config from src.core.robot import RobotConfig, load_robot_config
@@ -38,7 +40,9 @@ class BaseEnv(abc.ABC, Generic[T]):
... ...
@abc.abstractmethod @abc.abstractmethod
def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor: def compute_rewards(
self, state: Any, actions: torch.Tensor, prev_actions: torch.Tensor,
) -> torch.Tensor:
... ...
@abc.abstractmethod @abc.abstractmethod
@@ -48,6 +52,21 @@ class BaseEnv(abc.ABC, Generic[T]):
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor: def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
return step_counts >= self.config.max_steps return step_counts >= self.config.max_steps
def initial_state_ranges(
self, nq: int, nv: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Per-DOF uniform ranges for initial-state randomization.
Returns (qpos_lo, qpos_hi, qvel_lo, qvel_hi) — offsets added to the
model's default state on every reset. All runners (CPU MuJoCo and
MJX) sample from these, so initial-state distributions stay
identical across backends. Default: small ±0.05 perturbation.
"""
return (
np.full(nq, -0.05), np.full(nq, 0.05),
np.full(nv, -0.05), np.full(nv, 0.05),
)
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool: def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
"""Check whether the physical robot has settled enough to start an episode. """Check whether the physical robot has settled enough to start an episode.

View File

@@ -3,12 +3,10 @@
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig from src.core.env import BaseEnv, BaseEnvConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
# Maps Hydra config-group name → (EnvClass, ConfigClass) # Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
} }

View File

@@ -15,6 +15,7 @@ import math
from pathlib import Path from pathlib import Path
import structlog import structlog
import torch
import yaml import yaml
log = structlog.get_logger() log = structlog.get_logger()
@@ -51,6 +52,9 @@ class ActuatorConfig:
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
viscous_quadratic: float = 0.0 # velocity² drag coefficient viscous_quadratic: float = 0.0 # velocity² drag coefficient
back_emf_gain: float = 0.0 # back-EMF torque reduction back_emf_gain: float = 0.0 # back-EMF torque reduction
stribeck_friction_boost: float = 0.0 # extra static friction at low speed (N·m)
stribeck_vel: float = 2.0 # Stribeck decay velocity (rad/s)
action_bias: float = 0.0 # additive ctrl bias (driver asymmetry)
@property @property
def gear_avg(self) -> float: def gear_avg(self) -> float:
@@ -66,10 +70,23 @@ class ActuatorConfig:
or self.frictionloss != (0.0, 0.0) or self.frictionloss != (0.0, 0.0)
or self.viscous_quadratic > 0 or self.viscous_quadratic > 0
or self.back_emf_gain > 0 or self.back_emf_gain > 0
or self.stribeck_friction_boost > 0
or self.action_bias != 0.0
) )
def transform_ctrl(self, ctrl: float) -> float: def transform_ctrl(self, ctrl: float) -> float:
"""Apply asymmetric deadzone and gear compensation to a scalar ctrl.""" """Clip to ctrl_range, then apply bias, deadzone and gear compensation.
Must stay in lock-step with the vectorised JAX version in
``src/runners/mjx.py`` (step_fn) — sysid fits parameters against
THIS function, so any drift breaks the identified model.
"""
# Clip to ctrl_range first (mirrors firmware PWM saturation).
ctrl = max(self.ctrl_range[0], min(self.ctrl_range[1], ctrl))
# Additive driver bias (e.g. H-bridge asymmetry).
ctrl += self.action_bias
# Deadzone # Deadzone
dz_pos, dz_neg = self.deadzone dz_pos, dz_neg = self.deadzone
if ctrl >= 0 and ctrl < dz_pos: if ctrl >= 0 and ctrl < dz_pos:
@@ -88,19 +105,25 @@ class ActuatorConfig:
def compute_motor_force(self, vel: float, ctrl: float, def compute_motor_force(self, vel: float, ctrl: float,
friction_scale: float = 1.0, friction_scale: float = 1.0,
damping_scale: float = 1.0) -> float: damping_scale: float = 1.0) -> float:
"""Asymmetric friction, damping, drag, back-EMF → applied torque. """Asymmetric friction (Coulomb + Stribeck), damping, drag, back-EMF.
``friction_scale`` / ``damping_scale`` multiply the Coulomb-friction ``friction_scale`` / ``damping_scale`` multiply the friction and
and viscous-damping terms for per-env domain randomization viscous-damping terms for per-env domain randomization
(1.0 = no randomization, the default used by sysid). (1.0 = no randomization, the default used by sysid).
""" """
torque = 0.0 torque = 0.0
# Coulomb friction (direction-dependent) # Coulomb + Stribeck friction (direction-dependent). The Stribeck
# boost adds extra friction at low speed that decays as exp(-(v/vs)²)
# — crucial for cheap brushed motors near standstill.
fl_pos, fl_neg = self.frictionloss fl_pos, fl_neg = self.frictionloss
if abs(vel) > 1e-6: if abs(vel) > 1e-6:
fl = (fl_pos if vel > 0 else fl_neg) * friction_scale fl = fl_pos if vel > 0 else fl_neg
torque -= math.copysign(fl, vel) if self.stribeck_friction_boost > 0:
fl += self.stribeck_friction_boost * math.exp(
-((abs(vel) / self.stribeck_vel) ** 2)
)
torque -= math.copysign(fl * friction_scale, vel)
# Viscous damping (direction-dependent) # Viscous damping (direction-dependent)
damp = (self.damping[0] if vel > 0 else self.damping[1]) * damping_scale damp = (self.damping[0] if vel > 0 else self.damping[1]) * damping_scale
@@ -117,20 +140,26 @@ class ActuatorConfig:
return max(-10.0, min(10.0, torque)) return max(-10.0, min(10.0, torque))
def transform_action(self, action): def transform_action(self, action):
"""Vectorised deadzone + gear compensation for a torch batch.""" """Vectorised clip + bias + deadzone + gear compensation (torch batch).
Must produce the same result as ``transform_ctrl`` element-wise.
"""
action = action.clamp(self.ctrl_range[0], self.ctrl_range[1])
action = action + self.action_bias
dz_pos, dz_neg = self.deadzone dz_pos, dz_neg = self.deadzone
if dz_pos > 0 or dz_neg > 0: if dz_pos > 0 or dz_neg > 0:
action = action.clone()
pos_dead = (action >= 0) & (action < dz_pos) pos_dead = (action >= 0) & (action < dz_pos)
neg_dead = (action < 0) & (action > -dz_neg) neg_dead = (action < 0) & (action > -dz_neg)
action[pos_dead | neg_dead] = 0.0 action = action.masked_fill(pos_dead | neg_dead, 0.0)
gear_avg = self.gear_avg gear_avg = self.gear_avg
if gear_avg > 1e-8 and self.gear[0] != self.gear[1]: if gear_avg > 1e-8 and self.gear[0] != self.gear[1]:
action = action.clone() if dz_pos == 0 and dz_neg == 0 else action
pos = action >= 0 pos = action >= 0
action[pos] *= self.gear[0] / gear_avg action = torch.where(
action[~pos] *= self.gear[1] / gear_avg pos, action * (self.gear[0] / gear_avg),
action * (self.gear[1] / gear_avg),
)
return action return action
@@ -176,9 +205,18 @@ def load_robot_config(robot_dir: str | Path) -> RobotConfig:
if not urdf_path.exists(): if not urdf_path.exists():
raise FileNotFoundError(f"URDF not found: {urdf_path}") raise FileNotFoundError(f"URDF not found: {urdf_path}")
# Parse actuators # Parse actuators — ignore unknown keys (newer sysid exports may add
# fields before the loader learns about them) instead of crashing.
known_fields = {f.name for f in dataclasses.fields(ActuatorConfig)}
actuators = [] actuators = []
for a in raw.get("actuators", []): for a in raw.get("actuators", []):
unknown = set(a) - known_fields
if unknown:
log.warning(
"robot_yaml_unknown_actuator_keys",
keys=sorted(unknown), file=str(yaml_path),
)
a = {k: v for k, v in a.items() if k in known_fields}
if "ctrl_range" in a: if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"]) a["ctrl_range"] = tuple(a["ctrl_range"])
for key in ("gear", "deadzone", "damping", "frictionloss"): for key in ("gear", "deadzone", "damping", "frictionloss"):

View File

@@ -14,8 +14,7 @@ T = TypeVar("T")
class BaseRunnerConfig: class BaseRunnerConfig:
num_envs: int = 1 num_envs: int = 1
device: str = "cpu" device: str = "cpu"
history_length: int = 0 # 0 = no history (single obs), >0 = RMA-style history_length: int = 0 # 0 = plain obs, >0 = append (obs, action) history
rma_mode: str = "none" # "none" | "teacher" | "deploy"
# ── Domain randomization (sim-to-real) ───────────────────────── # ── Domain randomization (sim-to-real) ─────────────────────────
# Empty dict = disabled (every field below is a no-op). Supported keys: # Empty dict = disabled (every field below is a no-op). Supported keys:
@@ -25,7 +24,8 @@ class BaseRunnerConfig:
# friction_scale: [lo, hi] — per-env multiplier on Coulomb friction # friction_scale: [lo, hi] — per-env multiplier on Coulomb friction
# damping_scale: [lo, hi] — per-env multiplier on viscous damping # damping_scale: [lo, hi] — per-env multiplier on viscous damping
# torque_scale: [lo, hi] — per-env multiplier on applied motor torque # torque_scale: [lo, hi] — per-env multiplier on applied motor torque
# The randomized factors are exposed as privileged_obs (μ) for RMA. # With history_length > 0 the policy can implicitly infer the sampled
# dynamics from the recent (obs, action) window — end-to-end adaptation.
domain_rand: dict = dataclasses.field(default_factory=dict) domain_rand: dict = dataclasses.field(default_factory=dict)
class BaseRunner(abc.ABC, Generic[T]): class BaseRunner(abc.ABC, Generic[T]):
@@ -50,20 +50,13 @@ class BaseRunner(abc.ABC, Generic[T]):
) )
# ── Domain randomization (latency / sensor noise / dynamics) ─ # ── Domain randomization (latency / sensor noise / dynamics) ─
# Must precede the RMA block: teacher mode reads privileged_dim,
# which is derived from the randomized-factor count below.
self._setup_domain_rand() self._setup_domain_rand()
# ── RMA mode ──────────────────────────────────────────── # ── History buffer (implicit adaptation input) ────────────
self._rma_mode: str = getattr(self.config, "rma_mode", "none")
# ── History buffer (used in "deploy" and legacy "none" modes) ─
self._history_len: int = getattr(self.config, "history_length", 0) self._history_len: int = getattr(self.config, "history_length", 0)
if self._history_len > 0: if self._history_len > 0:
obs_dim = self.observation_space.shape[0] obs_dim = self.observation_space.shape[0]
act_dim = self.action_space.shape[0] act_dim = self.action_space.shape[0]
self._history_obs_dim = obs_dim
self._history_act_dim = act_dim
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action) self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
# Ring buffer: (num_envs, history_length, obs_dim + act_dim) # Ring buffer: (num_envs, history_length, obs_dim + act_dim)
self._history_buf = torch.zeros( self._history_buf = torch.zeros(
@@ -71,27 +64,9 @@ class BaseRunner(abc.ABC, Generic[T]):
device=self.config.device, device=self.config.device,
) )
# ── Observation space augmentation ─────────────────────── # Policy obs = [raw_obs, history_flat]
from gymnasium import spaces from gymnasium import spaces
raw_obs_dim = self.observation_space.shape[0] aug_dim = obs_dim + self._history_len * self._history_step_dim
if self._rma_mode == "teacher":
# Teacher gets [raw_obs, μ]. μ dim resolved after _sim_initialize.
mu_dim = getattr(self, "privileged_dim", 0)
if mu_dim > 0:
aug_dim = raw_obs_dim + mu_dim
self.observation_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(aug_dim,),
)
elif self._rma_mode == "deploy" and self._history_len > 0:
# Deploy gets [raw_obs, history_flat].
aug_dim = raw_obs_dim + self._history_len * self._history_step_dim
self.observation_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(aug_dim,),
)
elif self._rma_mode == "none" and self._history_len > 0:
# Legacy mode: [raw_obs, history_flat].
aug_dim = raw_obs_dim + self._history_len * self._history_step_dim
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(aug_dim,), low=-torch.inf, high=torch.inf, shape=(aug_dim,),
) )
@@ -116,6 +91,12 @@ class BaseRunner(abc.ABC, Generic[T]):
@abc.abstractmethod @abc.abstractmethod
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Reset the given envs; return FULL-batch (num_envs, nq/nv) state.
Returning the full batch (not just the reset envs) lets GPU
backends hand back zero-copy views without host synchronisation —
the caller indexes the reset rows itself.
"""
... ...
def _sim_close(self) -> None: def _sim_close(self) -> None:
@@ -128,10 +109,10 @@ class BaseRunner(abc.ABC, Generic[T]):
_SCALE_FIELDS = ("friction_scale", "damping_scale", "torque_scale") _SCALE_FIELDS = ("friction_scale", "damping_scale", "torque_scale")
def _setup_domain_rand(self) -> None: def _setup_domain_rand(self) -> None:
"""Parse the domain_rand config into per-env buffers + the μ layout. """Parse the domain_rand config into per-env buffers.
All buffers are no-ops when ``domain_rand`` is empty: scales are 1.0, All buffers are no-ops when ``domain_rand`` is empty: scales are 1.0,
delay is 0, noise std is 0, and privileged_dim is 0. delay is 0 and noise std is 0.
""" """
dr = dict(getattr(self.config, "domain_rand", {}) or {}) dr = dict(getattr(self.config, "domain_rand", {}) or {})
n = self.config.num_envs n = self.config.num_envs
@@ -145,27 +126,21 @@ class BaseRunner(abc.ABC, Generic[T]):
self._dr_scales: dict[str, torch.Tensor] = { self._dr_scales: dict[str, torch.Tensor] = {
f: torch.ones(n, device=dev) for f in self._SCALE_FIELDS f: torch.ones(n, device=dev) for f in self._SCALE_FIELDS
} }
self._dr_scale_ranges: dict[str, tuple[float, float]] = {}
# Per-env integer action delay (in control steps).
self._dr_delay = torch.zeros(n, dtype=torch.long, device=dev)
# Spec list — its order fixes the privileged μ vector layout.
self._dr_specs: list[tuple[str, float, float]] = []
delay_range = dr.get("action_delay_steps")
if delay_range:
self._dr_specs.append(
("action_delay_steps", float(delay_range[0]), float(delay_range[1]))
)
self._max_delay = int(delay_range[1])
else:
self._max_delay = 0
for f in self._SCALE_FIELDS: for f in self._SCALE_FIELDS:
rng = dr.get(f) rng = dr.get(f)
if rng: if rng:
self._dr_specs.append((f, float(rng[0]), float(rng[1]))) self._dr_scale_ranges[f] = (float(rng[0]), float(rng[1]))
self._mu_dim = len(self._dr_specs) # Per-env integer action delay (in control steps).
self._dr_mu = torch.zeros(n, self._mu_dim, device=dev) self._dr_delay = torch.zeros(n, dtype=torch.long, device=dev)
delay_range = dr.get("action_delay_steps")
if delay_range:
self._delay_range = (int(delay_range[0]), int(delay_range[1]))
self._max_delay = int(delay_range[1])
else:
self._delay_range = (0, 0)
self._max_delay = 0
# Action-delay ring buffer: (num_envs, max_delay + 1, act_dim). # Action-delay ring buffer: (num_envs, max_delay + 1, act_dim).
if self._max_delay > 0: if self._max_delay > 0:
@@ -176,23 +151,17 @@ class BaseRunner(abc.ABC, Generic[T]):
def _resample_domain_rand(self, env_ids: torch.Tensor) -> None: def _resample_domain_rand(self, env_ids: torch.Tensor) -> None:
"""Sample fresh per-env DR factors (call on every (re)set).""" """Sample fresh per-env DR factors (call on every (re)set)."""
if self._mu_dim == 0 or env_ids.numel() == 0: if env_ids.numel() == 0:
return return
dev = self.config.device dev = self.config.device
for j, (name, lo, hi) in enumerate(self._dr_specs): for name, (lo, hi) in self._dr_scale_ranges.items():
if name == "action_delay_steps":
vals = torch.randint(
int(lo), int(hi) + 1, (env_ids.numel(),), device=dev,
)
self._dr_delay[env_ids] = vals
raw = vals.float()
else:
vals = torch.rand(env_ids.numel(), device=dev) * (hi - lo) + lo vals = torch.rand(env_ids.numel(), device=dev) * (hi - lo) + lo
self._dr_scales[name][env_ids] = vals self._dr_scales[name][env_ids] = vals
raw = vals if self._max_delay > 0:
# Normalize each factor to ~[-1, 1] for the privileged μ vector. self._dr_delay[env_ids] = torch.randint(
span = (hi - lo) if (hi - lo) > 1e-9 else 1.0 self._delay_range[0], self._delay_range[1] + 1,
self._dr_mu[env_ids, j] = 2.0 * (raw - lo) / span - 1.0 (env_ids.numel(),), device=dev,
)
def _reset_action_buffer(self, env_ids: torch.Tensor) -> None: def _reset_action_buffer(self, env_ids: torch.Tensor) -> None:
if self._max_delay > 0: if self._max_delay > 0:
@@ -225,35 +194,10 @@ class BaseRunner(abc.ABC, Generic[T]):
nqpos, nqvel = self._add_sensor_noise(qpos, qvel) nqpos, nqvel = self._add_sensor_noise(qpos, qvel)
return self.env.compute_observations(self.env.build_state(nqpos, nqvel)) return self.env.compute_observations(self.env.build_state(nqpos, nqvel))
# ── Privileged observation interface ─────────────────────────
@property
def privileged_dim(self) -> int:
"""Number of randomized DR factors exposed as μ (0 if DR disabled)."""
return getattr(self, "_mu_dim", 0)
@property
def privileged_obs(self) -> torch.Tensor:
"""Per-env normalized DR factors μ ∈ [-1, 1] for RMA supervision."""
if getattr(self, "_mu_dim", 0) > 0:
return self._dr_mu
return torch.zeros(self.config.num_envs, 0, device=self.config.device)
# ── Observation augmentation ───────────────────────────────── # ── Observation augmentation ─────────────────────────────────
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor: def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
"""Augment raw obs based on RMA mode. """Append the flattened (obs, action) history when enabled."""
teacher: [raw_obs, μ]
deploy: [raw_obs, history_flat]
none: [raw_obs, history_flat] (legacy, or plain obs if no history)
"""
if self._rma_mode == "teacher":
mu = self.privileged_obs
if mu.shape[-1] > 0:
return torch.cat([obs, mu], dim=-1)
return obs
# deploy / none: concatenate history
if self._history_len <= 0: if self._history_len <= 0:
return obs return obs
hist_flat = self._history_buf.reshape(obs.shape[0], -1) hist_flat = self._history_buf.reshape(obs.shape[0], -1)
@@ -292,6 +236,11 @@ class BaseRunner(abc.ABC, Generic[T]):
return self._augment_obs(obs), {} return self._augment_obs(obs), {}
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]: def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
prev_actions = (
self._last_actions
if self._last_actions is not None
else torch.zeros_like(actions)
)
self._last_actions = actions self._last_actions = actions
# Latency: the simulator applies a (per-env) delayed action. # Latency: the simulator applies a (per-env) delayed action.
sim_actions = self._apply_action_delay(actions) sim_actions = self._apply_action_delay(actions)
@@ -301,7 +250,7 @@ class BaseRunner(abc.ABC, Generic[T]):
# Reward / termination use the TRUE state (no sensor noise) so the # Reward / termination use the TRUE state (no sensor noise) so the
# learning signal and safety checks stay clean. # learning signal and safety checks stay clean.
clean_state = self.env.build_state(qpos, qvel) clean_state = self.env.build_state(qpos, qvel)
rewards = self.env.compute_rewards(clean_state, actions) rewards = self.env.compute_rewards(clean_state, actions, prev_actions)
terminated = self.env.compute_terminations(clean_state) terminated = self.env.compute_terminations(clean_state)
truncated = self.env.compute_truncations(self.step_counts) truncated = self.env.compute_truncations(self.step_counts)
@@ -313,10 +262,6 @@ class BaseRunner(abc.ABC, Generic[T]):
info: dict[str, Any] = {} info: dict[str, Any] = {}
# Expose privileged DR params μ for RMA supervision (this step's dynamics).
if self.privileged_dim > 0:
info["privileged_obs"] = self.privileged_obs.clone()
done = terminated | truncated done = terminated | truncated
done_ids = done.nonzero(as_tuple=False).squeeze(-1) done_ids = done.nonzero(as_tuple=False).squeeze(-1)
@@ -327,11 +272,14 @@ class BaseRunner(abc.ABC, Generic[T]):
# New episode → fresh dynamics + cleared latency buffer. # New episode → fresh dynamics + cleared latency buffer.
self._resample_domain_rand(done_ids) self._resample_domain_rand(done_ids)
self._reset_action_buffer(done_ids) self._reset_action_buffer(done_ids)
reset_qpos, reset_qvel = self._sim_reset(done_ids) full_qpos, full_qvel = self._sim_reset(done_ids)
self.step_counts[done_ids] = 0 self.step_counts[done_ids] = 0
self._reset_history(done_ids) self._reset_history(done_ids)
obs[done_ids] = self._compute_obs(reset_qpos, reset_qvel) # _sim_reset returns the full batch — index the reset rows here.
obs[done_ids] = self._compute_obs(
full_qpos[done_ids], full_qvel[done_ids],
)
# skrl expects (num_envs, 1) for rewards/terminated/truncated # skrl expects (num_envs, 1) for rewards/terminated/truncated
return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

View File

@@ -1,53 +0,0 @@
import dataclasses
import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces
@dataclasses.dataclass
class CartPoleState:
cart_pos: torch.float # (num_envs,)
cart_vel: torch.float # (num_envs,)
pole_angle: torch.float # (num_envs,)
pole_vel: torch.float # (num_envs,)
@dataclasses.dataclass
class CartPoleConfig(BaseEnvConfig):
"""CartPole task config. All values come from Hydra YAML."""
angle_threshold: float = 0.418 # ~24 degrees
cart_limit: float = 2.4
reward_alive: float = 1.0
reward_pole_upright_scale: float = 1.0
reward_action_penalty_scale: float = 0.01
class CartPoleEnv(BaseEnv[CartPoleConfig]):
def __init__(self, config: CartPoleConfig):
super().__init__(config)
@property
def observation_space(self) -> torch.Tensor:
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))
@property
def action_space(self) -> torch.Tensor:
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
return CartPoleState(
cart_pos=qpos[:, 0],
cart_vel=qvel[:, 0],
pole_angle=qpos[:, 1],
pole_vel=qvel[:, 1],
)
def compute_observations(self, state: CartPoleState) -> torch.Tensor:
return torch.stack([state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1)
def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
return self.config.reward_alive + upright - action_penalty
def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
return pole_fallen | cart_out_of_bounds

View File

@@ -1,9 +1,12 @@
import dataclasses import dataclasses
import math import math
import numpy as np
import torch import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces from gymnasium import spaces
from src.core.env import BaseEnv, BaseEnvConfig
@dataclasses.dataclass @dataclasses.dataclass
class RotaryCartPoleState: class RotaryCartPoleState:
@@ -28,6 +31,8 @@ class RotaryCartPoleConfig(BaseEnvConfig):
motor_vel_penalty: float = 0.01 # penalise high motor angular velocity motor_vel_penalty: float = 0.01 # penalise high motor angular velocity
motor_angle_penalty: float = 0.05 # penalise deviation from centre motor_angle_penalty: float = 0.05 # penalise deviation from centre
action_penalty: float = 0.05 # penalise large actions (energy cost) action_penalty: float = 0.05 # penalise large actions (energy cost)
action_rate_penalty: float = 0.01 # penalise action changes (smoothness —
# critical with ~100 ms real motor lag)
# ── Initial state randomisation ── # ── Initial state randomisation ──
pendulum_init_range_deg: float = 180.0 # pendulum starts in [-range, +range] pendulum_init_range_deg: float = 180.0 # pendulum starts in [-range, +range]
@@ -87,7 +92,12 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
# ── Rewards ────────────────────────────────────────────────── # ── Rewards ──────────────────────────────────────────────────
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor: def compute_rewards(
self,
state: RotaryCartPoleState,
actions: torch.Tensor,
prev_actions: torch.Tensor,
) -> torch.Tensor:
# Upright shaping ∈ [0, 1]: 0 hanging down (θ=0), 1 fully upright (θ=π). # Upright shaping ∈ [0, 1]: 0 hanging down (θ=0), 1 fully upright (θ=π).
# Non-negative by design so *surviving* always beats ending the episode early # Non-negative by design so *surviving* always beats ending the episode early
# (otherwise the optimum is to slam the arm into the ±limit — "suicide policy"). # (otherwise the optimum is to slam the arm into the ±limit — "suicide policy").
@@ -115,6 +125,11 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
# Penalise large actions (energy efficiency / smoother control) # Penalise large actions (energy efficiency / smoother control)
reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2) reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2)
# Penalise rapid action changes — a jittery policy is unrealisable
# through the real motor's ~100 ms lag and excites unmodeled dynamics.
action_rate = (actions - prev_actions).squeeze(-1).pow(2)
reward = reward - self.config.action_rate_penalty * action_rate
# Penalty for exceeding motor angle limit (episode also terminates) # Penalty for exceeding motor angle limit (episode also terminates)
limit_rad = math.radians(self.config.motor_angle_limit_deg) limit_rad = math.radians(self.config.motor_angle_limit_deg)
exceeded = state.motor_angle.abs() >= limit_rad exceeded = state.motor_angle.abs() >= limit_rad
@@ -122,6 +137,22 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
return reward return reward
# ── Initial state randomization ──────────────────────────────
def initial_state_ranges(
self, nq: int, nv: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Small motor perturbation; wide pendulum angle (swing-up task)."""
qpos_lo = np.full(nq, -0.05)
qpos_hi = np.full(nq, 0.05)
qvel_lo = np.full(nv, -0.05)
qvel_hi = np.full(nv, 0.05)
pend_range = math.radians(self.config.pendulum_init_range_deg)
if pend_range > 0 and nq >= 2:
qpos_lo[1] = -pend_range
qpos_hi[1] = pend_range
return qpos_lo, qpos_hi, qvel_lo, qvel_hi
# ── Terminations ───────────────────────────────────────────── # ── Terminations ─────────────────────────────────────────────
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor: def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:

View File

@@ -11,8 +11,9 @@ class HistoryEncoder(nn.Module):
Output: (batch, embedding_dim) Output: (batch, embedding_dim)
Architecture: two temporal conv layers → global average pool → linear. Architecture: two temporal conv layers → global average pool → linear.
Used as the adaptation module φ in RMA Phase 2 (deploy mode) to Lets the policy implicitly infer the current dynamics (friction, torque
predict the latent ẑ from recent dynamics. scale, latency, …) from how the system responded to recent actions —
end-to-end adaptation when trained under domain randomization.
""" """
def __init__( def __init__(
@@ -42,35 +43,13 @@ class HistoryEncoder(nn.Module):
return self.fc(x) return self.fc(x)
class EnvironmentEncoder(nn.Module):
"""MLP that compresses privileged DR parameters μ into a latent z.
Used in RMA Phase 1 (teacher mode): z = e(μ).
"""
def __init__(self, mu_dim: int, latent_dim: int = 8, hidden: int = 64) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(mu_dim, hidden),
nn.ELU(),
nn.Linear(hidden, latent_dim),
)
def forward(self, mu: torch.Tensor) -> torch.Tensor:
"""mu: (batch, mu_dim) → z: (batch, latent_dim)."""
return self.net(mu)
class SharedMLP(GaussianMixin, DeterministicMixin, Model): class SharedMLP(GaussianMixin, DeterministicMixin, Model):
"""Shared policy/value network with RMA support. """Shared policy/value network with an optional history encoder.
rma_mode: With ``history_length > 0`` the input states are expected to be
"none" legacy: optional history encoder, plain obs, backward compat. ``[raw_obs, history_flat]`` (as produced by ``BaseRunner``); the history
"teacher" Phase 1: env_encoder(μ) → z, backbone input = [raw_obs, z]. window is compressed by a :class:`HistoryEncoder` and concatenated with
"deploy" Phase 2+: adaptation_module(history) → ẑ, input = [raw_obs, ẑ]. the raw observation before the shared backbone.
Teacher and deploy modes produce the same backbone in_dim
(raw_obs_dim + latent_dim), so weights transfer cleanly.
""" """
def __init__( def __init__(
@@ -84,14 +63,10 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
min_log_std: float = -2.0, min_log_std: float = -2.0,
max_log_std: float = 2.0, max_log_std: float = 2.0,
initial_log_std: float = 0.0, initial_log_std: float = 0.0,
# ── Legacy (none mode) ─────────────────────────────────── # ── History encoder ──────────────────────────────────────
history_length: int = 0, history_length: int = 0,
raw_obs_dim: int = 0, raw_obs_dim: int = 0,
embedding_dim: int = 32, embedding_dim: int = 32,
# ── RMA modes ────────────────────────────────────────────
rma_mode: str = "none",
latent_dim: int = 8,
mu_dim: int = 0,
): ):
Model.__init__(self, observation_space, action_space, device) Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__( GaussianMixin.__init__(
@@ -99,40 +74,13 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
) )
DeterministicMixin.__init__(self, clip_actions) DeterministicMixin.__init__(self, clip_actions)
self._rma_mode = rma_mode
self._history_length = history_length self._history_length = history_length
self._raw_obs_dim = raw_obs_dim self._raw_obs_dim = raw_obs_dim
self._embedding_dim = embedding_dim self._embedding_dim = embedding_dim
self._latent_dim = latent_dim
self._mu_dim = mu_dim
# ── Build encoder + determine backbone in_dim ──────────── self.history_encoder: HistoryEncoder | None = None
self.env_encoder: EnvironmentEncoder | None = None if history_length > 0 and raw_obs_dim > 0:
self.adaptation_module: HistoryEncoder | None = None step_dim = raw_obs_dim + self.num_actions
self.history_encoder: HistoryEncoder | None = None # legacy
if rma_mode == "teacher":
assert mu_dim > 0 and raw_obs_dim > 0
self.env_encoder = EnvironmentEncoder(
mu_dim=mu_dim, latent_dim=latent_dim,
)
in_dim = raw_obs_dim + latent_dim
elif rma_mode == "deploy":
assert history_length > 0 and raw_obs_dim > 0
act_dim = self.num_actions
step_dim = raw_obs_dim + act_dim
self.adaptation_module = HistoryEncoder(
history_length=history_length,
step_dim=step_dim,
embedding_dim=latent_dim,
)
in_dim = raw_obs_dim + latent_dim
elif history_length > 0 and raw_obs_dim > 0:
# Legacy "none" mode with history encoder.
act_dim = self.num_actions
step_dim = raw_obs_dim + act_dim
self.history_encoder = HistoryEncoder( self.history_encoder = HistoryEncoder(
history_length=history_length, history_length=history_length,
step_dim=step_dim, step_dim=step_dim,
@@ -169,25 +117,10 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
return DeterministicMixin.act(self, inputs, role) return DeterministicMixin.act(self, inputs, role)
def _encode(self, states: torch.Tensor) -> torch.Tensor: def _encode(self, states: torch.Tensor) -> torch.Tensor:
"""Route through the correct encoder based on rma_mode.""" """Optionally split off and encode the history window."""
if self._rma_mode == "teacher": if self.history_encoder is None:
# states = [raw_obs, μ] return self.net(states)
obs = states[:, :self._raw_obs_dim]
mu = states[:, self._raw_obs_dim:]
z = self.env_encoder(mu)
return self.net(torch.cat([obs, z], dim=-1))
if self._rma_mode == "deploy":
# states = [raw_obs, history_flat]
obs = states[:, :self._raw_obs_dim]
hist_flat = states[:, self._raw_obs_dim:]
step_dim = self._raw_obs_dim + self.num_actions
history = hist_flat.reshape(-1, self._history_length, step_dim)
z_hat = self.adaptation_module(history)
return self.net(torch.cat([obs, z_hat], dim=-1))
# Legacy "none" mode.
if self.history_encoder is not None:
obs = states[:, :self._raw_obs_dim] obs = states[:, :self._raw_obs_dim]
hist_flat = states[:, self._raw_obs_dim:] hist_flat = states[:, self._raw_obs_dim:]
step_dim = self._raw_obs_dim + self.num_actions step_dim = self._raw_obs_dim + self.num_actions
@@ -195,8 +128,6 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
embedding = self.history_encoder(history) embedding = self.history_encoder(history)
return self.net(torch.cat([obs, embedding], dim=-1)) return self.net(torch.cat([obs, embedding], dim=-1))
return self.net(states)
def compute( def compute(
self, inputs: dict[str, torch.Tensor], role: str = "", self, inputs: dict[str, torch.Tensor], role: str = "",
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:

View File

@@ -88,7 +88,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
mujoco.mj_forward(self._mj_model, default_data) mujoco.mj_forward(self._mj_model, default_data)
self._default_mjx_data = mjx.put_data(self._mj_model, default_data) self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
# Step 4: Initialise all environments with small perturbations # Env-defined initial-state distribution (shared with the CPU
# runner) — baked into the JIT reset as constants.
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
self._nq, self._nv,
)
self._init_qpos_lo = jnp.asarray(qpos_lo)
self._init_qpos_hi = jnp.asarray(qpos_hi)
self._init_qvel_lo = jnp.asarray(qvel_lo)
self._init_qvel_hi = jnp.asarray(qvel_hi)
# Step 4: Initialise all environments with randomized states
self._rng = jax.random.PRNGKey(42) self._rng = jax.random.PRNGKey(42)
self._batch_data = self._make_batched_data(config.num_envs) self._batch_data = self._make_batched_data(config.num_envs)
@@ -124,10 +134,16 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
) )
def _make_batched_data(self, n: int): def _make_batched_data(self, n: int):
"""Create *n* environments with small random perturbations.""" """Create *n* environments with env-defined initial randomization."""
self._rng, k1, k2 = jax.random.split(self._rng, 3) self._rng, k1, k2 = jax.random.split(self._rng, 3)
pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05) pq = jax.random.uniform(
pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05) k1, (n, self._nq),
minval=self._init_qpos_lo, maxval=self._init_qpos_hi,
)
pv = jax.random.uniform(
k2, (n, self._nv),
minval=self._init_qvel_lo, maxval=self._init_qvel_hi,
)
default = self._default_mjx_data default = self._default_mjx_data
model = self._mjx_model model = self._mjx_model
@@ -154,11 +170,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
act_gs = jnp.array(lim.gear_sign) act_gs = jnp.array(lim.gear_sign)
# ── Motor model params (JAX arrays for JIT) ───────────────── # ── Motor model params (JAX arrays for JIT) ─────────────────
# Must stay in lock-step with ActuatorConfig.transform_ctrl() /
# compute_motor_force() in src/core/robot.py — sysid fits against
# the CPU implementation.
_has_motor = len(self._motor_info) > 0 _has_motor = len(self._motor_info) > 0
if _has_motor: if _has_motor:
acts = self._motor_acts acts = self._motor_acts
_ctrl_ids = jnp.array([c for c, _ in self._motor_info]) _ctrl_ids = jnp.array([c for c, _ in self._motor_info])
_qvel_ids = jnp.array([q for _, q in self._motor_info]) _qvel_ids = jnp.array([q for _, q in self._motor_info])
_ctrl_lo = jnp.array([a.ctrl_range[0] for a in acts])
_ctrl_hi = jnp.array([a.ctrl_range[1] for a in acts])
_bias = jnp.array([a.action_bias for a in acts])
_dz_pos = jnp.array([a.deadzone[0] for a in acts]) _dz_pos = jnp.array([a.deadzone[0] for a in acts])
_dz_neg = jnp.array([a.deadzone[1] for a in acts]) _dz_neg = jnp.array([a.deadzone[1] for a in acts])
_gear_pos = jnp.array([a.gear[0] for a in acts]) _gear_pos = jnp.array([a.gear[0] for a in acts])
@@ -166,6 +188,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
_gear_avg = jnp.array([a.gear_avg for a in acts]) _gear_avg = jnp.array([a.gear_avg for a in acts])
_fl_pos = jnp.array([a.frictionloss[0] for a in acts]) _fl_pos = jnp.array([a.frictionloss[0] for a in acts])
_fl_neg = jnp.array([a.frictionloss[1] for a in acts]) _fl_neg = jnp.array([a.frictionloss[1] for a in acts])
_strb_boost = jnp.array([a.stribeck_friction_boost for a in acts])
_strb_vel = jnp.array([a.stribeck_vel for a in acts])
_damp_pos = jnp.array([a.damping[0] for a in acts]) _damp_pos = jnp.array([a.damping[0] for a in acts])
_damp_neg = jnp.array([a.damping[1] for a in acts]) _damp_neg = jnp.array([a.damping[1] for a in acts])
_visc_quad = jnp.array([a.viscous_quadratic for a in acts]) _visc_quad = jnp.array([a.viscous_quadratic for a in acts])
@@ -189,8 +213,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl) ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl)
if _has_motor: if _has_motor:
# Deadzone + asymmetric gear compensation # Clip → bias → deadzone asymmetric gear compensation
# (same order as ActuatorConfig.transform_ctrl).
mc = ctrl[:, _ctrl_ids] mc = ctrl[:, _ctrl_ids]
mc = jnp.clip(mc, _ctrl_lo, _ctrl_hi)
mc = mc + _bias
mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc) mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc)
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc) mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg) gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
@@ -205,8 +232,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
vel = d.qvel[:, _qvel_ids] vel = d.qvel[:, _qvel_ids]
mc = d.ctrl[:, _ctrl_ids] mc = d.ctrl[:, _ctrl_ids]
# Coulomb friction (direction-dependent) × DR scale # Coulomb + Stribeck friction (direction-dependent) × DR
fl = jnp.where(vel > 0, _fl_pos, _fl_neg) * fr fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
fl = fl + _strb_boost * jnp.exp(
-((jnp.abs(vel) / _strb_vel) ** 2)
)
fl = fl * fr
torque = -jnp.where( torque = -jnp.where(
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0, jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
) )
@@ -232,16 +263,23 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._jit_step = step_fn self._jit_step = step_fn
# ── Selective reset ───────────────────────────────────────── # ── Selective reset ─────────────────────────────────────────
init_qpos_lo = self._init_qpos_lo
init_qpos_hi = self._init_qpos_hi
init_qvel_lo = self._init_qvel_lo
init_qvel_hi = self._init_qvel_hi
@jax.jit @jax.jit
def reset_fn(data, mask, rng): def reset_fn(data, mask, rng):
rng, k1, k2 = jax.random.split(rng, 3) rng, k1, k2 = jax.random.split(rng, 3)
ne = data.qpos.shape[0] ne = data.qpos.shape[0]
pq = jax.random.uniform( pq = jax.random.uniform(
k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05, k1, (ne, default.qpos.shape[0]),
minval=init_qpos_lo, maxval=init_qpos_hi,
) )
pv = jax.random.uniform( pv = jax.random.uniform(
k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05, k2, (ne, default.qvel.shape[0]),
minval=init_qvel_lo, maxval=init_qvel_hi,
) )
m = mask[:, None] # (num_envs, 1) broadcast helper m = mask[:, None] # (num_envs, 1) broadcast helper
@@ -293,11 +331,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous()) self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous()) self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
# Return only the reset environments' states # Return the FULL batch (BaseRunner indexes the reset envs in torch)
ids_np = env_ids.cpu().numpy() # — avoids a GPU→CPU sync + JAX gather on every step with a done env.
rq = self._batch_data.qpos[ids_np].astype(jnp.float32) qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
rv = self._batch_data.qvel[ids_np].astype(jnp.float32) qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
return torch.from_dlpack(rq), torch.from_dlpack(rv) return qpos, qvel
# ── Rendering ──────────────────────────────────────────────────── # ── Rendering ────────────────────────────────────────────────────

View File

@@ -1,5 +1,4 @@
import dataclasses import dataclasses
import math
import os import os
import tempfile import tempfile
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@@ -278,32 +277,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
ids = env_ids.cpu().numpy() ids = env_ids.cpu().numpy()
n = len(ids)
qpos_batch = np.zeros((n, self._nq), dtype=np.float32) # Env-defined initial-state distribution (shared with the MJX runner).
qvel_batch = np.zeros((n, self._nv), dtype=np.float32) qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
self._nq, self._nv,
)
# Check if env has a pendulum_init_range_deg for wide pendulum randomisation for env_id in ids:
pend_range = getattr(self.env.config, "pendulum_init_range_deg", 0.0)
pend_range_rad = math.radians(pend_range) if pend_range > 0 else 0.0
for i, env_id in enumerate(ids):
data = self._data[env_id] data = self._data[env_id]
mujoco.mj_resetData(self._model, data) mujoco.mj_resetData(self._model, data)
# Small random perturbation for motor angle + velocity data.qpos[:] += np.random.uniform(qpos_lo, qpos_hi)
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq) data.qvel[:] += np.random.uniform(qvel_lo, qvel_hi)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
# Wide pendulum angle randomisation (overrides the small perturbation)
if pend_range_rad > 0 and self._nq >= 2:
data.qpos[1] = np.random.uniform(-pend_range_rad, pend_range_rad)
data.ctrl[:] = 0.0 data.ctrl[:] = 0.0
qpos_batch[i] = data.qpos # Full-batch return (see BaseRunner._sim_reset contract).
qvel_batch[i] = data.qvel qpos_batch = np.stack([d.qpos for d in self._data]).astype(np.float32)
qvel_batch = np.stack([d.qvel for d in self._data]).astype(np.float32)
return ( return (
torch.from_numpy(qpos_batch).to(self.device), torch.from_numpy(qpos_batch).to(self.device),
torch.from_numpy(qvel_batch).to(self.device), torch.from_numpy(qvel_batch).to(self.device),

View File

@@ -22,19 +22,25 @@ log = structlog.get_logger()
def export_tuned_files( def export_tuned_files(
robot_path: str | Path, robot_path: str | Path,
params: dict[str, float], params: dict[str, float],
motor_params: dict[str, float] | None = None,
) -> tuple[Path, Path]: ) -> tuple[Path, Path]:
"""Write tuned URDF and robot.yaml files. """Write tuned URDF and robot.yaml files.
Parameters Parameters
---------- ----------
robot_path : robot asset directory (contains robot.yaml + *.urdf) robot_path : robot asset directory (contains robot.yaml + *.urdf)
params : dict of parameter name → tuned value (unified, all 28 params) params : dict of parameter name → tuned value (the optimised set)
motor_params : locked motor parameters merged underneath ``params``
(``params`` wins on conflicts) so the exported YAML always has a
complete motor model
Returns Returns
------- -------
(tuned_urdf_path, tuned_robot_yaml_path) (tuned_urdf_path, tuned_robot_yaml_path)
""" """
robot_path = Path(robot_path).resolve() robot_path = Path(robot_path).resolve()
if motor_params:
params = {**motor_params, **params}
# ── Load originals ─────────────────────────────────────────── # ── Load originals ───────────────────────────────────────────
robot_yaml_path = robot_path / "robot.yaml" robot_yaml_path = robot_path / "robot.yaml"

View File

@@ -7,12 +7,13 @@ the simulated trajectory for comparison with the real recording.
This module is the inner loop of the CMA-ES optimizer: it is called once This module is the inner loop of the CMA-ES optimizer: it is called once
per candidate parameter vector per generation. per candidate parameter vector per generation.
Motor parameters are **locked** from the motor-only sysid result. Motor parameters are **locked** from the unified sysid result.
The optimizer only fits The optimizer only fits
pendulum/arm inertial parameters, pendulum joint dynamics, and pendulum/arm inertial parameters, pendulum joint dynamics, and
``ctrl_limit``. The asymmetric motor model (deadzone, gear compensation, ``ctrl_limit``. The asymmetric motor model (bias, deadzone, gear
Coulomb friction, viscous damping, quadratic drag, back-EMF) is applied compensation, Coulomb + Stribeck friction, viscous damping) is applied
via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()``. via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()``
the same code the training runners use, so sim == sysid by construction.
""" """
from __future__ import annotations from __future__ import annotations
@@ -32,23 +33,25 @@ from src.runners.mujoco import ActuatorLimits, load_mujoco_model
from src.sysid._urdf import patch_link_inertials from src.sysid._urdf import patch_link_inertials
# ── Locked motor parameters (from motor-only sysid) ──────────────── # ── Locked motor parameters (from the unified sysid) ────────────────
# These are FIXED and not optimised. They come from the 12-param model # These are FIXED and not optimised. They come from the unified
# in robot.yaml (from motor-only sysid, cost 0.862). # 28-param sysid run (assets/rotary_cartpole/sysid_result.json,
# cost 0.925) — Stribeck friction + action bias + ~96 ms motor lag.
LOCKED_MOTOR_PARAMS: dict[str, float] = { LOCKED_MOTOR_PARAMS: dict[str, float] = {
"actuator_gear_pos": 0.424182, "actuator_gear_pos": 0.846499,
"actuator_gear_neg": 0.425031, "actuator_gear_neg": 1.183733,
"actuator_filter_tau": 0.00503506, "actuator_filter_tau": 0.096263,
"motor_damping_pos": 0.00202682, "motor_damping_pos": 0.013165,
"motor_damping_neg": 0.0146651, "motor_damping_neg": 0.015452,
"motor_armature": 0.00277342, "motor_armature": 0.001676,
"motor_frictionloss_pos": 0.0573282, "motor_frictionloss_pos": 0.014244,
"motor_frictionloss_neg": 0.0533549, "motor_frictionloss_neg": 0.001005,
"viscous_quadratic": 0.000285329, "stribeck_friction_boost": 0.068594,
"back_emf_gain": 0.00675809, "stribeck_vel": 5.279594,
"motor_deadzone_pos": 0.141291, "motor_deadzone_pos": 0.181097,
"motor_deadzone_neg": 0.0780148, "motor_deadzone_neg": 0.202072,
"action_bias": 0.056566,
} }
@@ -190,26 +193,34 @@ def _build_model(
act_cfg = robot_yaml["actuators"][0] act_cfg = robot_yaml["actuators"][0]
ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0]) ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])
# The fitted ctrl_limit overrides the YAML ctrl_range so the rollout
# saturates at exactly the identified PWM bound.
if "ctrl_limit" in params:
ctrl_lo, ctrl_hi = -params["ctrl_limit"], params["ctrl_limit"]
actuator = ActuatorConfig( actuator = ActuatorConfig(
joint=act_cfg["joint"], joint=act_cfg["joint"],
type="motor", type="motor",
gear=(gear_pos, gear_neg), gear=(gear_pos, gear_neg),
ctrl_range=(ctrl_lo, ctrl_hi), ctrl_range=(ctrl_lo, ctrl_hi),
deadzone=( deadzone=(
motor_params.get("motor_deadzone_pos", 0.141), motor_params.get("motor_deadzone_pos", 0.181),
motor_params.get("motor_deadzone_neg", 0.078), motor_params.get("motor_deadzone_neg", 0.202),
), ),
damping=( damping=(
motor_params.get("motor_damping_pos", 0.002), motor_params.get("motor_damping_pos", 0.013),
motor_params.get("motor_damping_neg", 0.015), motor_params.get("motor_damping_neg", 0.015),
), ),
frictionloss=( frictionloss=(
motor_params.get("motor_frictionloss_pos", 0.057), motor_params.get("motor_frictionloss_pos", 0.014),
motor_params.get("motor_frictionloss_neg", 0.053), motor_params.get("motor_frictionloss_neg", 0.001),
), ),
filter_tau=motor_params.get("actuator_filter_tau", 0.005), filter_tau=motor_params.get("actuator_filter_tau", 0.096),
viscous_quadratic=motor_params.get("viscous_quadratic", 0.000285), viscous_quadratic=motor_params.get("viscous_quadratic", 0.0),
back_emf_gain=motor_params.get("back_emf_gain", 0.00676), back_emf_gain=motor_params.get("back_emf_gain", 0.0),
stribeck_friction_boost=motor_params.get("stribeck_friction_boost", 0.0),
stribeck_vel=motor_params.get("stribeck_vel", 2.0),
action_bias=motor_params.get("action_bias", 0.0),
) )
robot = RobotConfig( robot = RobotConfig(
@@ -276,7 +287,6 @@ def rollout(
mujoco.mj_resetData(model, data) mujoco.mj_resetData(model, data)
n = len(actions) n = len(actions)
ctrl_limit = params.get("ctrl_limit", 0.588)
sim_motor_angle = np.zeros(n, dtype=np.float64) sim_motor_angle = np.zeros(n, dtype=np.float64)
sim_motor_vel = np.zeros(n, dtype=np.float64) sim_motor_vel = np.zeros(n, dtype=np.float64)
@@ -286,8 +296,8 @@ def rollout(
limits = ActuatorLimits(model) limits = ActuatorLimits(model)
for i in range(n): for i in range(n):
action = max(-ctrl_limit, min(ctrl_limit, float(actions[i]))) # transform_ctrl clips to the (fitted) ctrl_range internally.
ctrl = actuator.transform_ctrl(action) ctrl = actuator.transform_ctrl(float(actions[i]))
data.ctrl[0] = ctrl data.ctrl[0] = ctrl
for _ in range(substeps): for _ in range(substeps):
@@ -378,7 +388,6 @@ def windowed_rollout(
window_starts.append(idx) window_starts.append(idx)
current_t += window_duration current_t += window_duration
ctrl_limit = params.get("ctrl_limit", 0.588)
n_windows = len(window_starts) n_windows = len(window_starts)
for w, w_start in enumerate(window_starts): for w, w_start in enumerate(window_starts):
@@ -393,8 +402,8 @@ def windowed_rollout(
mujoco.mj_forward(model, data) mujoco.mj_forward(model, data)
for i in range(w_start, w_end): for i in range(w_start, w_end):
action = max(-ctrl_limit, min(ctrl_limit, float(actions[i]))) # transform_ctrl clips to the (fitted) ctrl_range internally.
ctrl = actuator.transform_ctrl(action) ctrl = actuator.transform_ctrl(float(actions[i]))
data.ctrl[0] = ctrl data.ctrl[0] = ctrl
for _ in range(substeps): for _ in range(substeps):

View File

@@ -31,6 +31,7 @@ class TrainerConfig:
clip_ratio: float = 0.2 clip_ratio: float = 0.2
value_loss_scale: float = 0.5 value_loss_scale: float = 0.5
entropy_loss_scale: float = 0.01 entropy_loss_scale: float = 0.01
kl_threshold: float = 0.01 # KL-adaptive LR target; 0 = fixed LR
hidden_sizes: tuple[int, ...] = (64, 64) hidden_sizes: tuple[int, ...] = (64, 64)
@@ -48,14 +49,10 @@ class TrainerConfig:
record_video_every: int = 10_000 # 0 = disabled record_video_every: int = 10_000 # 0 = disabled
record_video_fps: int = 0 # 0 = derive from sim dt×substeps record_video_fps: int = 0 # 0 = derive from sim dt×substeps
# History encoder (RMA-style adaptation) # History encoder (implicit adaptation). The window size comes from
history_length: int = 0 # 0 = disabled, >0 = temporal window size # the runner (runner.history_length) — single source of truth.
embedding_dim: int = 32 # history encoder output dimension embedding_dim: int = 32 # history encoder output dimension
# RMA (Rapid Motor Adaptation)
rma_mode: str = "none" # "none" | "teacher" | "deploy"
latent_dim: int = 8 # env encoder / adaptation latent dimension
# ── Video-recording trainer ────────────────────────────────────────── # ── Video-recording trainer ──────────────────────────────────────────
@@ -107,13 +104,18 @@ class VideoRecordingTrainer(SequentialTrainer):
else: else:
states = next_states states = next_states
# Periodic video recording # Periodic video recording. Recording steps the (shared) envs,
# so it returns a freshly reset observation — the training loop
# MUST continue from it, otherwise the recorded transitions no
# longer match the actual env state.
if ( if (
self._tcfg self._tcfg
and self._tcfg.record_video_every > 0 and self._tcfg.record_video_every > 0
and (timestep + 1) % self._tcfg.record_video_every == 0 and (timestep + 1) % self._tcfg.record_video_every == 0
): ):
self._record_video(timestep + 1) fresh_states = self._record_video(timestep + 1)
if fresh_states is not None:
states = fresh_states
# ── helpers ─────────────────────────────────────────────────────── # ── helpers ───────────────────────────────────────────────────────
@@ -125,14 +127,21 @@ class VideoRecordingTrainer(SequentialTrainer):
# SerialRunner has dt but no substeps — dt *is* the control period. # SerialRunner has dt but no substeps — dt *is* the control period.
return max(1, int(round(1.0 / (dt * substeps)))) return max(1, int(round(1.0 / (dt * substeps))))
def _record_video(self, timestep: int) -> None: def _record_video(self, timestep: int) -> torch.Tensor | None:
"""Record an eval episode and upload it to ClearML.
Returns the freshly reset observation the training loop should
continue from (the recording steps the shared envs), or ``None``
if even the final reset failed.
"""
try: try:
import imageio.v3 as iio import imageio.v3 as iio
except ImportError: except ImportError:
return iio = None
# Rendering needs a GL backend (EGL/OSMesa); never let a headless GL # Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
# failure crash training — log it and skip the video. # failure crash training — log it and skip the video.
if iio is not None:
try: try:
fps = self._get_fps() fps = self._get_fps()
max_steps = getattr(self.env.env.config, "max_steps", 500) max_steps = getattr(self.env.env.config, "max_steps", 500)
@@ -161,11 +170,19 @@ class VideoRecordingTrainer(SequentialTrainer):
"Training Video", f"step_{timestep}", "Training Video", f"step_{timestep}",
local_path=path, iteration=timestep, local_path=path, iteration=timestep,
) )
self.env.reset()
except Exception as exc: except Exception as exc:
log.warning("video_recording_failed", timestep=timestep, error=str(exc)) log.warning("video_recording_failed", timestep=timestep, error=str(exc))
# Always leave the envs freshly reset and hand the new observation
# back to the training loop.
try:
with torch.no_grad():
states, _ = self.env.reset()
return states
except Exception as exc:
log.warning("post_video_reset_failed", timestep=timestep, error=str(exc))
return None
# ── Main trainer ───────────────────────────────────────────────────── # ── Main trainer ─────────────────────────────────────────────────────
@@ -186,11 +203,11 @@ class Trainer:
device=device, device=device,
) )
# Determine raw obs dim (without history/privileged augmentation). # Determine raw obs dim (without history augmentation) and the
# history window size — both come from the runner so the model
# always matches the observation layout it produces.
raw_obs_dim = self.runner.env.observation_space.shape[0] raw_obs_dim = self.runner.env.observation_space.shape[0]
history_length = getattr(self.runner.config, "history_length", 0)
# Privileged dimension (μ) from the runner, if available.
mu_dim = getattr(self.runner, "privileged_dim", 0)
self.model = SharedMLP( self.model = SharedMLP(
observation_space=obs_space, observation_space=obs_space,
@@ -200,12 +217,9 @@ class Trainer:
initial_log_std=self.config.initial_log_std, initial_log_std=self.config.initial_log_std,
min_log_std=self.config.min_log_std, min_log_std=self.config.min_log_std,
max_log_std=self.config.max_log_std, max_log_std=self.config.max_log_std,
history_length=self.config.history_length, history_length=history_length,
raw_obs_dim=raw_obs_dim, raw_obs_dim=raw_obs_dim,
embedding_dim=self.config.embedding_dim, embedding_dim=self.config.embedding_dim,
rma_mode=self.config.rma_mode,
latent_dim=self.config.latent_dim,
mu_dim=mu_dim,
) )
models = {"policy": self.model, "value": self.model} models = {"policy": self.model, "value": self.model}
@@ -221,11 +235,20 @@ class Trainer:
"ratio_clip": self.config.clip_ratio, "ratio_clip": self.config.clip_ratio,
"value_loss_scale": self.config.value_loss_scale, "value_loss_scale": self.config.value_loss_scale,
"entropy_loss_scale": self.config.entropy_loss_scale, "entropy_loss_scale": self.config.entropy_loss_scale,
# Truncation (time limit) must bootstrap from the value function;
# without this the value target is biased at every max_steps cut.
"time_limit_bootstrap": True,
"state_preprocessor": RunningStandardScaler, "state_preprocessor": RunningStandardScaler,
"state_preprocessor_kwargs": {"size": obs_space, "device": device}, "state_preprocessor_kwargs": {"size": obs_space, "device": device},
"value_preprocessor": RunningStandardScaler, "value_preprocessor": RunningStandardScaler,
"value_preprocessor_kwargs": {"size": 1, "device": device}, "value_preprocessor_kwargs": {"size": 1, "device": device},
}) })
if self.config.kl_threshold > 0:
from skrl.resources.schedulers.torch import KLAdaptiveLR
agent_cfg["learning_rate_scheduler"] = KLAdaptiveLR
agent_cfg["learning_rate_scheduler_kwargs"] = {
"kl_threshold": self.config.kl_threshold,
}
# Wire up logging frequency: write_interval is in timesteps. # Wire up logging frequency: write_interval is in timesteps.
# log_interval=1 → log every PPO update (= every rollout_steps timesteps). # log_interval=1 → log every PPO update (= every rollout_steps timesteps).
agent_cfg["experiment"]["write_interval"] = self.config.log_interval agent_cfg["experiment"]["write_interval"] = self.config.log_interval

7
tests/conftest.py Normal file
View File

@@ -0,0 +1,7 @@
import sys
from pathlib import Path
# Make `src.*` importable regardless of pytest invocation directory.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)

79
tests/test_reward.py Normal file
View File

@@ -0,0 +1,79 @@
"""Reward design tests — balancing must strictly dominate spinning."""
import math
from pathlib import Path
import pytest
import torch
from src.envs.rotary_cartpole import (
RotaryCartPoleConfig,
RotaryCartPoleEnv,
RotaryCartPoleState,
)
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
def _env() -> RotaryCartPoleEnv:
return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
t = lambda v: torch.tensor([float(v)])
return RotaryCartPoleState(
motor_angle=t(motor), motor_vel=t(motor_vel),
pendulum_angle=t(pend), pendulum_vel=t(pend_vel),
)
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
a = torch.tensor([[float(action)]])
pa = torch.tensor([[float(prev_action)]])
return float(env.compute_rewards(state, a, pa)[0])
def test_balancing_beats_spinning_through_upright():
env = _env()
balanced = _state(pend=math.pi, pend_vel=0.0)
spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top
assert _reward(env, balanced) > 2.0 * _reward(env, spinning)
def test_average_spin_cycle_reward_below_balance():
"""Mean reward over a full revolution at high speed << balanced reward."""
env = _env()
angles = torch.linspace(0, 2 * math.pi, 32)
spin_rewards = [
_reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
]
mean_spin = sum(spin_rewards) / len(spin_rewards)
balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
assert balanced > 3.0 * mean_spin
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
env = _env()
over_limit = _state(motor=math.radians(95.0), pend=math.pi)
assert _reward(env, over_limit) == -10.0
assert bool(env.compute_terminations(over_limit)[0])
def test_action_rate_penalty_reduces_reward():
env = _env()
s = _state(pend=math.pi)
smooth = _reward(env, s, action=0.5, prev_action=0.5)
jerky = _reward(env, s, action=0.5, prev_action=-0.5)
assert smooth > jerky
assert smooth - jerky == pytest.approx(
env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6,
)
def test_initial_state_ranges_widen_pendulum_only():
env = _env()
qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()

125
tests/test_robot_config.py Normal file
View File

@@ -0,0 +1,125 @@
"""Robot config loading + motor model unit tests."""
import math
from pathlib import Path
import pytest
import torch
from src.core.robot import ActuatorConfig, load_robot_config
ROBOT_DIR = Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole"
# ── Loading ──────────────────────────────────────────────────────────
def test_canonical_robot_yaml_loads_full_motor_model():
robot = load_robot_config(ROBOT_DIR)
act = robot.actuators[0]
assert act.has_motor_model
# Tuned (unified sysid) values must survive the round-trip.
assert act.filter_tau == pytest.approx(0.096263)
assert act.stribeck_friction_boost == pytest.approx(0.068594)
assert act.stribeck_vel == pytest.approx(5.279594)
assert act.action_bias == pytest.approx(0.056566)
assert act.gear == pytest.approx((0.846499, 1.183733))
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
(tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
(tmp_path / "robot.yaml").write_text(
"urdf: dummy.urdf\n"
"actuators:\n"
" - joint: j\n"
" type: motor\n"
" gear: [1.0, 1.0]\n"
" some_future_field: 42\n"
)
robot = load_robot_config(tmp_path) # must not raise
assert robot.actuators[0].joint == "j"
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
@pytest.fixture
def act() -> ActuatorConfig:
return ActuatorConfig(
joint="m",
gear=(0.8, 1.2),
ctrl_range=(-0.6, 0.6),
deadzone=(0.15, 0.20),
frictionloss=(0.014, 0.001),
damping=(0.013, 0.015),
stribeck_friction_boost=0.07,
stribeck_vel=5.0,
action_bias=0.05,
)
def test_transform_ctrl_clips_to_ctrl_range(act):
# 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
out = act.transform_ctrl(1.0)
assert out == pytest.approx((0.6 + 0.05) * 0.8 / 1.0)
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
# 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
assert act.transform_ctrl(0.05) == 0.0
# -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
assert act.transform_ctrl(-0.15) == 0.0
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
pos = act.transform_ctrl(0.5) # (0.55) * 0.8
neg = act.transform_ctrl(-0.5) # (-0.45) * 1.2
assert pos == pytest.approx(0.55 * 0.8)
assert neg == pytest.approx(-0.45 * 1.2)
def test_transform_action_matches_transform_ctrl_elementwise(act):
vals = torch.linspace(-1.2, 1.2, 49)
batched = act.transform_action(vals.clone())
scalar = torch.tensor([act.transform_ctrl(float(v)) for v in vals])
assert torch.allclose(batched, scalar, atol=1e-6)
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
def test_friction_opposes_motion(act):
assert act.compute_motor_force(vel=2.0, ctrl=0.0) < 0
assert act.compute_motor_force(vel=-2.0, ctrl=0.0) > 0
assert act.compute_motor_force(vel=0.0, ctrl=0.0) == 0.0
def test_stribeck_boost_decays_with_speed(act):
"""Friction torque magnitude (minus damping) is higher near standstill."""
no_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
)
with_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
stribeck_friction_boost=0.07, stribeck_vel=5.0,
)
v_slow, v_fast = 0.1, 50.0
extra_slow = abs(with_strb.compute_motor_force(v_slow, 0.0)) - abs(
no_strb.compute_motor_force(v_slow, 0.0))
extra_fast = abs(with_strb.compute_motor_force(v_fast, 0.0)) - abs(
no_strb.compute_motor_force(v_fast, 0.0))
assert extra_slow == pytest.approx(
0.07 * math.exp(-((v_slow / 5.0) ** 2)), abs=1e-9)
assert extra_fast < extra_slow
assert extra_fast == pytest.approx(0.0, abs=1e-9)
def test_friction_scale_dr_multiplies_friction(act):
base = act.compute_motor_force(1.0, 0.0, friction_scale=1.0, damping_scale=0.0)
# damping_scale=0 isolates the friction term
doubled = act.compute_motor_force(1.0, 0.0, friction_scale=2.0, damping_scale=0.0)
assert doubled == pytest.approx(2.0 * base)

173
tests/test_runner.py Normal file
View File

@@ -0,0 +1,173 @@
"""Runner integration tests — DR, history, action delay, init randomization.
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
is skipped when JAX isn't installed.
"""
import math
from pathlib import Path
import pytest
import torch
from src.core.registry import build_env
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
DR = {
"qpos_noise_std": 0.01,
"qvel_noise_std": 0.5,
"action_delay_steps": [0, 2],
"friction_scale": [0.6, 1.6],
"damping_scale": [0.6, 1.6],
"torque_scale": [0.85, 1.15],
}
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
cfg = MuJoCoRunnerConfig(
num_envs=num_envs,
device="cpu",
history_length=history_length,
domain_rand=domain_rand or {},
)
return MuJoCoRunner(env=env, config=cfg)
# ── Observation layout ───────────────────────────────────────────────
def test_obs_is_raw_plus_history():
runner = _runner(num_envs=2, history_length=3)
raw_dim = runner.env.observation_space.shape[0] # 6
step_dim = raw_dim + 1 # + action
assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim
obs, _ = runner.reset()
assert obs.shape == (2, raw_dim + 3 * step_dim)
# Fresh history must be zero.
assert torch.all(obs[:, raw_dim:] == 0)
actions = torch.full((2, 1), 0.3)
obs, rewards, term, trunc, info = runner.step(actions)
assert obs.shape == (2, raw_dim + 3 * step_dim)
assert rewards.shape == (2, 1)
# Newest history slot holds the commanded action.
assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
runner.close()
def test_no_history_keeps_plain_obs():
runner = _runner(num_envs=2, history_length=0)
assert runner.observation_space.shape[0] == 6
runner.close()
# ── Domain randomization ─────────────────────────────────────────────
def test_dr_scales_sampled_within_ranges_and_resampled():
runner = _runner(num_envs=16, domain_rand=DR)
runner.reset()
for name, (lo, hi) in (
("friction_scale", (0.6, 1.6)),
("damping_scale", (0.6, 1.6)),
("torque_scale", (0.85, 1.15)),
):
vals = runner._dr_scales[name]
assert torch.all(vals >= lo) and torch.all(vals <= hi)
# 16 independent uniform samples are never all identical.
assert vals.std() > 0
before = runner._dr_scales["friction_scale"].clone()
runner.reset()
assert not torch.equal(before, runner._dr_scales["friction_scale"])
delays = runner._dr_delay
assert torch.all(delays >= 0) and torch.all(delays <= 2)
runner.close()
def test_dr_disabled_is_noop():
runner = _runner(num_envs=2, domain_rand={})
runner.reset()
for vals in runner._dr_scales.values():
assert torch.all(vals == 1.0)
assert runner._max_delay == 0
assert runner._qpos_noise_std == 0.0
runner.close()
def test_action_delay_buffer_returns_lagged_action():
runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
runner.reset()
runner._dr_delay = torch.tensor([0, 1, 2])
runner._action_buf.zero_()
a1 = torch.tensor([[1.0], [1.0], [1.0]])
a2 = torch.tensor([[2.0], [2.0], [2.0]])
a3 = torch.tensor([[3.0], [3.0], [3.0]])
d1 = runner._apply_action_delay(a1)
d2 = runner._apply_action_delay(a2)
d3 = runner._apply_action_delay(a3)
assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
runner.close()
# ── Initial-state randomization ──────────────────────────────────────
def test_wide_pendulum_init_actually_applied():
runner = _runner(num_envs=32)
qpos, _ = runner._sim_reset(torch.arange(32))
pend_angles = qpos[:, 1]
# With ±180° init range, samples must spread far beyond the old ±0.05.
assert pend_angles.abs().max() > 1.0
assert pend_angles.std() > 0.5
runner.close()
def test_sim_reset_returns_full_batch():
runner = _runner(num_envs=4)
runner.reset()
qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only
assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
runner.close()
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
def test_mjx_runner_smoke():
pytest.importorskip("jax")
pytest.importorskip("mujoco.mjx")
from src.runners.mjx import MJXRunner, MJXRunnerConfig
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
runner = MJXRunner(
env=env,
config=MJXRunnerConfig(
num_envs=4, device="cpu", history_length=3, domain_rand=DR,
),
)
obs, _ = runner.reset()
raw_dim = env.observation_space.shape[0]
assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))
for _ in range(3):
actions = torch.rand(4, 1) * 2 - 1
obs, rewards, term, trunc, _ = runner.step(actions)
assert torch.isfinite(obs).all()
assert torch.isfinite(rewards).all()
# Wide pendulum init must reach MJX resets too.
qpos, qvel = runner._sim_reset(torch.arange(4))
assert qpos.shape[0] == 4 # full batch
runner.close()

View File