Compare commits
24 Commits
main
...
feature/si
| Author | SHA1 | Date | |
|---|---|---|---|
| 1e0836e1bc | |||
| a98e86ef66 | |||
| 4210b6cb53 | |||
| a6fbde798a | |||
| 56499ebe97 | |||
| b37cd26690 | |||
| 8cc84d6a21 | |||
| 8ed9afe583 | |||
| 5880997786 | |||
| ca0e7b8b03 | |||
| d3ed1c25ad | |||
| 3b2d6d08f9 | |||
| 23801857f4 | |||
| 3db68255f0 | |||
| 1a822bd82e | |||
| 4115447022 | |||
| 35223b3560 | |||
| 0f13086fee | |||
| 9813319275 | |||
| 70cd2cdd7d | |||
| 9be07d9186 | |||
| 26ccb1e902 | |||
| 15da0ef2fd | |||
| c753c369b4 |
29
.gitignore
vendored
29
.gitignore
vendored
@@ -1,3 +1,28 @@
|
||||
outputs/
|
||||
# IDE / editor
|
||||
.vscode/
|
||||
runs/
|
||||
|
||||
# Training & HPO outputs
|
||||
outputs/
|
||||
runs/
|
||||
smac3_output/
|
||||
training_log.txt
|
||||
.pytest_cache/
|
||||
|
||||
# Real-robot capture data (large .npz recordings)
|
||||
assets/**/recordings/
|
||||
|
||||
# MuJoCo
|
||||
MUJOCO_LOG.TXT
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
.eggs/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Temp files
|
||||
*.stl
|
||||
*.scad
|
||||
64
README.md
Normal file
64
README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# RL-Framework
|
||||
|
||||
A small, fast RL framework for training sim2real policies on a 3D-printed
|
||||
rotary (Furuta) cartpole — built to scale from a laptop CPU to a GPU worker
|
||||
(ClearML) without code changes, and to grow into more robots and simulators.
|
||||
|
||||
## Architecture
|
||||
|
||||
Three orthogonal pieces, composed by Hydra config groups:
|
||||
|
||||
| Piece | Role | Implementations |
|
||||
|---|---|---|
|
||||
| **Env** (`src/envs/`) | Task logic: obs / reward / termination / init distribution. Pure torch, batched, backend-agnostic. | `rotary_cartpole` |
|
||||
| **Runner** (`src/runners/`) | Physics + sim2real plumbing (DR, sensor noise, action delay, history buffer). | `mujoco` (CPU), `mjx` (GPU/JAX), `serial` (real ESP32 robot) |
|
||||
| **Trainer** (`src/training/`) | skrl PPO + shared MLP with optional history encoder. | `ppo`, `ppo_mjx`, `ppo_single`, `ppo_real` |
|
||||
|
||||
The robot itself is described once in `assets/<robot>/robot.yaml`
|
||||
(URDF + identified motor model) and shared by **training, sysid and
|
||||
deployment** — the motor model (bias → deadzone → gear compensation,
|
||||
Coulomb + Stribeck friction, viscous damping, first-order lag) is
|
||||
implemented in `src/core/robot.py` and mirrored exactly in the MJX JIT
|
||||
step (`src/runners/mjx.py`).
|
||||
|
||||
## Train
|
||||
|
||||
```bash
|
||||
# CPU (64 parallel MuJoCo envs)
|
||||
python scripts/train.py env=rotary_cartpole runner=mujoco training=ppo
|
||||
|
||||
# GPU (1024 MJX envs) — local
|
||||
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx
|
||||
|
||||
# GPU — remote on ClearML gpu-queue
|
||||
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx training.remote=true
|
||||
```
|
||||
|
||||
Videos and scalars stream to ClearML. Checkpoints land in `runs/`.
|
||||
|
||||
## Sim2real recipe
|
||||
|
||||
1. **Capture** real trajectories: `python -m src.sysid.capture` (writes `.npz` to `assets/<robot>/recordings/`).
|
||||
2. **Identify** physics: `python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz`
|
||||
— CMA-ES fits inertials/joint dynamics against the recording (motor model is locked from the unified sysid). Writes `sysid_result.json` + `robot_tuned.yaml` + `*_tuned.urdf`.
|
||||
3. **Validate** the fit: `python -m src.sysid.visualize`, then copy `robot_tuned.yaml` → `robot.yaml`.
|
||||
4. **Train with DR + history**: the runner randomizes friction/damping/torque scales, sensor noise and action latency per episode (`configs/runner/mjx.yaml: domain_rand`), and appends a 10-step (obs, action) history to the observation so the policy can implicitly identify the current dynamics (`history_length`).
|
||||
5. **Deploy**: `mjpython scripts/eval.py env=rotary_cartpole runner=serial checkpoint=runs/<run>/checkpoints/agent_X.pt`
|
||||
|
||||
## Other tools
|
||||
|
||||
```bash
|
||||
mjpython scripts/viz.py env=rotary_cartpole # keyboard-drive the sim
|
||||
mjpython scripts/viz.py runner=serial # digital twin of the real robot
|
||||
python scripts/hpo.py env=rotary_cartpole training=ppo_single # ClearML + SMAC3 HPO
|
||||
pytest tests/ # unit tests
|
||||
```
|
||||
|
||||
## Adding a robot / simulator
|
||||
|
||||
- **Robot**: drop `assets/<name>/` (URDF + `robot.yaml`), subclass `BaseEnv`
|
||||
(obs/reward/termination/`initial_state_ranges`), register in `src/core/registry.py`,
|
||||
add `configs/env/<name>.yaml`.
|
||||
- **Simulator**: subclass `BaseRunner` and implement `_sim_initialize`,
|
||||
`_sim_step`, `_sim_reset` (full-batch return) — DR, history and the
|
||||
env-side logic come for free. Register in `scripts/train.py: RUNNER_REGISTRY`.
|
||||
@@ -1,64 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<robot name="cartpole">
|
||||
|
||||
<!-- World link (fixed base) -->
|
||||
<link name="world"/>
|
||||
|
||||
<!-- Cart (slides along x-axis) -->
|
||||
<link name="cart">
|
||||
<inertial>
|
||||
<mass value="1.0"/>
|
||||
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<geometry>
|
||||
<box size="0.3 0.2 0.1"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<geometry>
|
||||
<box size="0.3 0.2 0.1"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Cart slides along x-axis -->
|
||||
<joint name="cart_joint" type="prismatic">
|
||||
<parent link="world"/>
|
||||
<child link="cart"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
|
||||
</joint>
|
||||
|
||||
<!-- Pole (rotates around y-axis, attached on top of cart) -->
|
||||
<link name="pole">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<mass value="0.1"/>
|
||||
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.02" length="0.6"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.02" length="0.6"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Pole rotates freely (no motor) -->
|
||||
<joint name="pole_joint" type="revolute">
|
||||
<parent link="cart"/>
|
||||
<child link="pole"/>
|
||||
<origin xyz="0 0 0.05"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
|
||||
<dynamics damping="0.0" friction="0.0"/>
|
||||
</joint>
|
||||
|
||||
</robot>
|
||||
10
assets/motor/hardware.yaml
Normal file
10
assets/motor/hardware.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# Motor-only hardware config
|
||||
# Encoder and motor constants for the motor-only sysid capture.
|
||||
|
||||
encoder:
|
||||
ppr: 11 # pulses per revolution (before quadrature)
|
||||
gear_ratio: 30.0 # gearbox ratio
|
||||
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
|
||||
|
||||
motor:
|
||||
max_pwm: 255 # maximum PWM command accepted by firmware
|
||||
40
assets/motor/motor.xml
Normal file
40
assets/motor/motor.xml
Normal file
@@ -0,0 +1,40 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<mujoco model="motor_sysid">
|
||||
<compiler angle="radian" autolimits="true"/>
|
||||
<option timestep="0.002" integrator="Euler"/>
|
||||
|
||||
<worldbody>
|
||||
<light pos="0 0 1" dir="0 0 -1"/>
|
||||
|
||||
<!-- Fixed base (gearbox housing) -->
|
||||
<body name="base" pos="0 0 0.15">
|
||||
<geom type="cylinder" size="0.04 0.06" mass="0.921"
|
||||
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
|
||||
|
||||
<!-- Arm: rotates around motor_joint (z-axis) -->
|
||||
<body name="arm" pos="0 0 0">
|
||||
<joint name="motor_joint" type="hinge" axis="0 0 1"
|
||||
range="-1.5708 1.5708"
|
||||
damping="0.001" armature="0.0001" frictionloss="0.03"/>
|
||||
|
||||
<!-- Rotor disk: mass is tunable by the optimizer -->
|
||||
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
|
||||
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
|
||||
contype="0" conaffinity="0"/>
|
||||
|
||||
<!-- Arm load: lightweight arm attached to motor shaft -->
|
||||
<geom name="arm_load" type="capsule" size="0.004"
|
||||
fromto="0 0 0 -0.014 0.002 0.016"
|
||||
mass="0.021" rgba="0.8 0.3 0.1 1"
|
||||
contype="0" conaffinity="0"/>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<general name="motor" joint="motor_joint"
|
||||
gear="0.064"
|
||||
ctrllimited="true" ctrlrange="-1 1"
|
||||
dyntype="filter" dynprm="0.03"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
42
assets/motor/motor_bare.xml
Normal file
42
assets/motor/motor_bare.xml
Normal file
@@ -0,0 +1,42 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Motor-only model WITHOUT arm/pendulum load.
|
||||
Use for arm-off sysid captures. arm_load mass set to ~0. -->
|
||||
<mujoco model="motor_sysid_bare">
|
||||
<compiler angle="radian" autolimits="true"/>
|
||||
<option timestep="0.002" integrator="Euler"/>
|
||||
|
||||
<worldbody>
|
||||
<light pos="0 0 1" dir="0 0 -1"/>
|
||||
|
||||
<!-- Fixed base (gearbox housing) -->
|
||||
<body name="base" pos="0 0 0.15">
|
||||
<geom type="cylinder" size="0.04 0.06" mass="0.921"
|
||||
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
|
||||
|
||||
<!-- Arm: rotates around motor_joint (z-axis) -->
|
||||
<body name="arm" pos="0 0 0">
|
||||
<joint name="motor_joint" type="hinge" axis="0 0 1"
|
||||
range="-1.5708 1.5708"
|
||||
damping="0.001" armature="0.0001" frictionloss="0.03"/>
|
||||
|
||||
<!-- Rotor disk: mass is tunable by the optimizer -->
|
||||
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
|
||||
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
|
||||
contype="0" conaffinity="0"/>
|
||||
|
||||
<!-- Arm load: near-zero mass for bare-shaft sysid -->
|
||||
<geom name="arm_load" type="capsule" size="0.004"
|
||||
fromto="0 0 0 -0.014 0.002 0.016"
|
||||
mass="0.001" rgba="0.8 0.3 0.1 0.3"
|
||||
contype="0" conaffinity="0"/>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<general name="motor" joint="motor_joint"
|
||||
gear="0.064"
|
||||
ctrllimited="true" ctrlrange="-1 1"
|
||||
dyntype="filter" dynprm="0.03"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
BIN
assets/motor/motor_sysid_comparison.png
Normal file
BIN
assets/motor/motor_sysid_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 463 KiB |
67
assets/motor/motor_sysid_result.json
Normal file
67
assets/motor/motor_sysid_result.json
Normal file
@@ -0,0 +1,67 @@
|
||||
{
|
||||
"best_params": {
|
||||
"actuator_gear_pos": 0.3711939014035462,
|
||||
"actuator_gear_neg": 0.42814281188601877,
|
||||
"actuator_filter_tau": 0.022300731564787457,
|
||||
"motor_damping_pos": 0.0013836218905629106,
|
||||
"motor_damping_neg": 0.005196351489379768,
|
||||
"motor_armature": 0.0027534181478216656,
|
||||
"motor_frictionloss_pos": 0.03674406439012955,
|
||||
"motor_frictionloss_neg": 0.06908200024786905,
|
||||
"viscous_quadratic": 0.000958226218765762,
|
||||
"back_emf_gain": 0.0036492272912788297,
|
||||
"stribeck_friction_boost": 0.044748043677129666,
|
||||
"stribeck_vel": 4.0513395945623705,
|
||||
"rotor_mass": 0.03982640764507874,
|
||||
"motor_deadzone_pos": 0.14181963932762467,
|
||||
"motor_deadzone_neg": 0.031454276545010214,
|
||||
"action_bias": -0.007362969452509152,
|
||||
"gearbox_backlash": 1.4749880999407965e-09
|
||||
},
|
||||
"best_cost": 0.21167928018839952,
|
||||
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/motor/recordings/motor_from_cartpole_162432.npz",
|
||||
"param_names": [
|
||||
"actuator_gear_pos",
|
||||
"actuator_gear_neg",
|
||||
"actuator_filter_tau",
|
||||
"motor_damping_pos",
|
||||
"motor_damping_neg",
|
||||
"motor_armature",
|
||||
"motor_frictionloss_pos",
|
||||
"motor_frictionloss_neg",
|
||||
"viscous_quadratic",
|
||||
"back_emf_gain",
|
||||
"stribeck_friction_boost",
|
||||
"stribeck_vel",
|
||||
"rotor_mass",
|
||||
"motor_deadzone_pos",
|
||||
"motor_deadzone_neg",
|
||||
"action_bias",
|
||||
"gearbox_backlash"
|
||||
],
|
||||
"defaults": {
|
||||
"actuator_gear_pos": 0.064,
|
||||
"actuator_gear_neg": 0.064,
|
||||
"actuator_filter_tau": 0.03,
|
||||
"motor_damping_pos": 0.003,
|
||||
"motor_damping_neg": 0.003,
|
||||
"motor_armature": 0.0001,
|
||||
"motor_frictionloss_pos": 0.03,
|
||||
"motor_frictionloss_neg": 0.03,
|
||||
"viscous_quadratic": 0.0,
|
||||
"back_emf_gain": 0.0,
|
||||
"stribeck_friction_boost": 0.0,
|
||||
"stribeck_vel": 2.0,
|
||||
"rotor_mass": 0.012,
|
||||
"motor_deadzone_pos": 0.08,
|
||||
"motor_deadzone_neg": 0.08,
|
||||
"action_bias": 0.0,
|
||||
"gearbox_backlash": 0.0
|
||||
},
|
||||
"timestamp": "2026-03-23T20:50:53.648753",
|
||||
"history_summary": {
|
||||
"first_cost": 5.105499579419285,
|
||||
"final_cost": 0.21167928018839952,
|
||||
"generations": 500
|
||||
}
|
||||
}
|
||||
19
assets/motor/motor_tuned.xml
Normal file
19
assets/motor/motor_tuned.xml
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version='1.0' encoding='utf-8'?>
|
||||
<mujoco model="motor_sysid">
|
||||
<compiler angle="radian" autolimits="true" />
|
||||
<option timestep="0.002" integrator="Euler" />
|
||||
<worldbody>
|
||||
<light pos="0 0 1" dir="0 0 -1" />
|
||||
<body name="base" pos="0 0 0.15">
|
||||
<geom type="cylinder" size="0.04 0.06" mass="0.921" rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0" />
|
||||
<body name="arm" pos="0 0 0">
|
||||
<joint name="motor_joint" type="hinge" axis="0 0 1" range="-1.5708 1.5708" damping="0" armature="0.0027534181478216656" frictionloss="0" />
|
||||
<geom name="rotor_disk" type="cylinder" size="0.008 0.004" mass="0.03982640764507874" pos="0 0 0" rgba="0.6 0.6 0.6 1" contype="0" conaffinity="0" />
|
||||
<geom name="arm_load" type="capsule" size="0.004" fromto="0 0 0 -0.014 0.002 0.016" mass="0.021" rgba="0.8 0.3 0.1 1" contype="0" conaffinity="0" />
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
<actuator>
|
||||
<general name="motor" joint="motor_joint" gear="0.3996683566447825" ctrllimited="true" ctrlrange="-1 1" dyntype="filter" dynprm="0.022300731564787457" />
|
||||
</actuator>
|
||||
</mujoco>
|
||||
6
assets/motor/robot.yaml
Normal file
6
assets/motor/robot.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
# Motor-only sysid config
|
||||
# Minimal config for the motor-only identification pipeline.
|
||||
# The optimizer patches motor.xml in-memory; this file tells it
|
||||
# which MJCF to load and provides encoder/hardware constants.
|
||||
|
||||
mjcf: motor.xml
|
||||
4
assets/motor/robot_bare.yaml
Normal file
4
assets/motor/robot_bare.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
# Motor-only sysid config — bare shaft (no arm/pendulum load)
|
||||
# Use with arm-off captures to identify pure motor dynamics.
|
||||
|
||||
mjcf: motor_bare.xml
|
||||
23
assets/motor/robot_tuned.yaml
Normal file
23
assets/motor/robot_tuned.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
# Tuned motor config — generated by src.sysid.motor.optimize
|
||||
# Original: robot.yaml
|
||||
|
||||
mjcf: motor_tuned.xml
|
||||
joints:
|
||||
motor_joint:
|
||||
armature: 0.002753
|
||||
frictionloss: 0.052913
|
||||
hardware_realism:
|
||||
actuator_gear_pos: 0.371194
|
||||
actuator_gear_neg: 0.428143
|
||||
motor_damping_pos: 0.001384
|
||||
motor_damping_neg: 0.005196
|
||||
motor_frictionloss_pos: 0.036744
|
||||
motor_frictionloss_neg: 0.069082
|
||||
motor_deadzone_pos: 0.14182
|
||||
motor_deadzone_neg: 0.031454
|
||||
action_bias: -0.007363
|
||||
viscous_quadratic: 0.000958
|
||||
back_emf_gain: 0.003649
|
||||
stribeck_friction_boost: 0.044748
|
||||
stribeck_vel: 4.05134
|
||||
gearbox_backlash: 0.0
|
||||
23
assets/rotary_cartpole/hardware.yaml
Normal file
23
assets/rotary_cartpole/hardware.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
# Rotary cartpole (Furuta pendulum) — real hardware config.
|
||||
# Describes the physical device for the SerialRunner.
|
||||
# Robot-specific constants that don't belong in the runner config
|
||||
# (which is machine-specific: port, baud) or the env config
|
||||
# (which is task-specific: rewards, max_steps).
|
||||
|
||||
encoder:
|
||||
ppr: 11 # pulses per revolution (before quadrature)
|
||||
gear_ratio: 30.0 # gearbox ratio
|
||||
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
|
||||
|
||||
safety:
|
||||
max_motor_angle_deg: 90.0 # hard termination limit (physical endstop ~70-80°)
|
||||
soft_limit_deg: 40.0 # progressive penalty ramp starts here
|
||||
|
||||
reset:
|
||||
drive_speed: 80 # PWM magnitude for bang-bang drive-to-center
|
||||
deadband: 15 # encoder count threshold to consider "centered"
|
||||
drive_timeout: 3.0 # seconds before giving up on drive-to-center
|
||||
settle_angle_deg: 2.0 # pendulum angle threshold for "still" (degrees)
|
||||
settle_vel_dps: 5.0 # pendulum velocity threshold (deg/s)
|
||||
settle_duration: 0.5 # how long pendulum must stay still (seconds)
|
||||
settle_timeout: 30.0 # give up waiting after this (seconds)
|
||||
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
Binary file not shown.
30
assets/rotary_cartpole/robot.yaml
Normal file
30
assets/rotary_cartpole/robot.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
# Canonical training model — unified sysid (cost 0.925, 475 generations).
|
||||
# Source: sysid_result.json → exported via src.sysid.export.
|
||||
# Key physics: ~96 ms motor lag (filter_tau), Stribeck friction, driver bias.
|
||||
# Regenerate with:
|
||||
# python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz
|
||||
# then copy robot_tuned.yaml over this file once validated
|
||||
# (python -m src.sysid.visualize to compare real vs sim).
|
||||
|
||||
urdf: rotary_cartpole_tuned.urdf
|
||||
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
type: motor
|
||||
gear: [0.846499, 1.183733] # torque constant [pos, neg]
|
||||
ctrl_range: [-0.686251, 0.686251] # PWM saturation (MAX_MOTOR_SPEED / 255)
|
||||
deadzone: [0.181097, 0.202072] # L298N min |ctrl| for torque [pos, neg]
|
||||
damping: [0.013165, 0.015452] # viscous damping [pos, neg]
|
||||
frictionloss: [0.014244, 0.001005] # Coulomb friction [pos, neg]
|
||||
filter_tau: 0.096263 # 1st-order actuator lag (s) — dominant!
|
||||
stribeck_friction_boost: 0.068594 # extra static friction near standstill
|
||||
stribeck_vel: 5.279594 # Stribeck decay velocity (rad/s)
|
||||
action_bias: 0.056566 # additive ctrl bias (driver asymmetry)
|
||||
|
||||
joints:
|
||||
motor_joint:
|
||||
armature: 0.001676 # reflected rotor inertia (kg·m²)
|
||||
frictionloss: 0.0 # handled by motor model via qfrc_applied
|
||||
pendulum_joint:
|
||||
damping: 1.2e-05
|
||||
frictionloss: 7.2e-05
|
||||
34
assets/rotary_cartpole/robot_tuned.yaml
Normal file
34
assets/rotary_cartpole/robot_tuned.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
# Tuned robot config — generated by src.sysid.optimize
|
||||
# Original: robot.yaml
|
||||
# Run `python -m src.sysid.visualize` to compare real vs sim.
|
||||
|
||||
urdf: rotary_cartpole_tuned.urdf
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
type: motor
|
||||
gear:
|
||||
- 0.846499
|
||||
- 1.183733
|
||||
ctrl_range:
|
||||
- -0.686251
|
||||
- 0.686251
|
||||
deadzone:
|
||||
- 0.181097
|
||||
- 0.202072
|
||||
damping:
|
||||
- 0.013165
|
||||
- 0.015452
|
||||
frictionloss:
|
||||
- 0.014244
|
||||
- 0.001005
|
||||
filter_tau: 0.096263
|
||||
stribeck_friction_boost: 0.068594
|
||||
stribeck_vel: 5.279594
|
||||
action_bias: 0.056566
|
||||
joints:
|
||||
motor_joint:
|
||||
armature: 0.001676
|
||||
frictionloss: 0.0
|
||||
pendulum_joint:
|
||||
damping: 1.2e-05
|
||||
frictionloss: 7.2e-05
|
||||
80
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
80
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
@@ -0,0 +1,80 @@
|
||||
<?xml version='1.0' encoding='utf-8'?>
|
||||
<robot name="rotary_cartpole">
|
||||
<link name="world" />
|
||||
<link name="base_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
|
||||
<mass value="0.921" />
|
||||
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="base_joint" type="fixed">
|
||||
<parent link="world" />
|
||||
<child link="base_link" />
|
||||
</joint>
|
||||
<link name="arm">
|
||||
<inertial>
|
||||
<origin xyz="-0.0071030505291264975 0.0008511826488989179 0.007952020186701035" rpy="0 0 0" />
|
||||
<mass value="0.02110029934220782" />
|
||||
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="motor_joint" type="revolute">
|
||||
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
|
||||
<parent link="base_link" />
|
||||
<child link="arm" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
|
||||
<dynamics damping="0.001" />
|
||||
</joint>
|
||||
<link name="pendulum">
|
||||
<inertial>
|
||||
<origin xyz="0.060245187591695615 -0.07601707109312682 -0.0034636702158137786" rpy="0 0 0" />
|
||||
<mass value="0.03936742845036306" />
|
||||
<inertia ixx="6.202768755990066e-05" iyy="3.70078470430685e-05" izz="7.827356811788924e-05" ixy="-6.925117819616428e-06" iyz="0.0" ixz="0.0" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="pendulum_joint" type="continuous">
|
||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
|
||||
<parent link="arm" />
|
||||
<child link="pendulum" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<dynamics damping="0.0001" />
|
||||
</joint>
|
||||
</robot>
|
||||
80
assets/rotary_cartpole/rotary_cartpole_tuned.urdf
Normal file
80
assets/rotary_cartpole/rotary_cartpole_tuned.urdf
Normal file
@@ -0,0 +1,80 @@
|
||||
<?xml version='1.0' encoding='utf-8'?>
|
||||
<robot name="rotary_cartpole">
|
||||
<link name="world" />
|
||||
<link name="base_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
|
||||
<mass value="0.921" />
|
||||
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="base_joint" type="fixed">
|
||||
<parent link="world" />
|
||||
<child link="base_link" />
|
||||
</joint>
|
||||
<link name="arm">
|
||||
<inertial>
|
||||
<origin xyz="0.02679980831009001 -0.015110803875962989 -0.005337417994989926" rpy="0 0 0" />
|
||||
<mass value="0.04052645463607292" />
|
||||
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="motor_joint" type="revolute">
|
||||
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
|
||||
<parent link="base_link" />
|
||||
<child link="arm" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
|
||||
<dynamics damping="0.001" />
|
||||
</joint>
|
||||
<link name="pendulum">
|
||||
<inertial>
|
||||
<origin xyz="0.04237886229290564 -0.05762212306183831 -0.0006039324398328591" rpy="0 0 0" />
|
||||
<mass value="0.056793277126969105" />
|
||||
<inertia ixx="0.0005956997356690322" iyy="8.986080748758447e-05" izz="0.0006855370063143978" ixy="-1.074983574165622e-05" iyz="0.0" ixz="0.0" />
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="pendulum_joint" type="continuous">
|
||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
|
||||
<parent link="arm" />
|
||||
<child link="pendulum" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<dynamics damping="0.0001" />
|
||||
</joint>
|
||||
</robot>
|
||||
BIN
assets/rotary_cartpole/sysid_comparison.png
Normal file
BIN
assets/rotary_cartpole/sysid_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 845 KiB |
101
assets/rotary_cartpole/sysid_result.json
Normal file
101
assets/rotary_cartpole/sysid_result.json
Normal file
@@ -0,0 +1,101 @@
|
||||
{
|
||||
"best_params": {
|
||||
"actuator_gear_pos": 0.8464986357922419,
|
||||
"actuator_gear_neg": 1.1837332515116656,
|
||||
"actuator_filter_tau": 0.09626340711982824,
|
||||
"motor_damping_pos": 0.013165358841570964,
|
||||
"motor_damping_neg": 0.015451769431035585,
|
||||
"motor_armature": 0.0016762295220909746,
|
||||
"motor_frictionloss_pos": 0.014244285616762612,
|
||||
"motor_frictionloss_neg": 0.0010053574956085138,
|
||||
"stribeck_friction_boost": 0.06859381652654695,
|
||||
"stribeck_vel": 5.279593819777324,
|
||||
"motor_deadzone_pos": 0.1810971675049997,
|
||||
"motor_deadzone_neg": 0.20207167371643614,
|
||||
"action_bias": 0.05656632270757292,
|
||||
"arm_mass": 0.04052645463607292,
|
||||
"arm_com_x": 0.02679980831009001,
|
||||
"arm_com_y": -0.015110803875962989,
|
||||
"arm_com_z": -0.005337417994989926,
|
||||
"pendulum_mass": 0.056793277126969105,
|
||||
"pendulum_com_x": 0.04237886229290564,
|
||||
"pendulum_com_y": -0.05762212306183831,
|
||||
"pendulum_com_z": -0.0006039324398328591,
|
||||
"pendulum_ixx": 0.0005956997356690322,
|
||||
"pendulum_iyy": 8.986080748758447e-05,
|
||||
"pendulum_izz": 0.0006855370063143978,
|
||||
"pendulum_ixy": -1.074983574165622e-05,
|
||||
"pendulum_damping": 1.1718327130851149e-05,
|
||||
"pendulum_frictionloss": 7.204244487092445e-05,
|
||||
"ctrl_limit": 0.6862506602999546
|
||||
},
|
||||
"best_cost": 0.9250391788859942,
|
||||
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260328_153749.npz",
|
||||
"param_names": [
|
||||
"actuator_gear_pos",
|
||||
"actuator_gear_neg",
|
||||
"actuator_filter_tau",
|
||||
"motor_damping_pos",
|
||||
"motor_damping_neg",
|
||||
"motor_armature",
|
||||
"motor_frictionloss_pos",
|
||||
"motor_frictionloss_neg",
|
||||
"stribeck_friction_boost",
|
||||
"stribeck_vel",
|
||||
"motor_deadzone_pos",
|
||||
"motor_deadzone_neg",
|
||||
"action_bias",
|
||||
"arm_mass",
|
||||
"arm_com_x",
|
||||
"arm_com_y",
|
||||
"arm_com_z",
|
||||
"pendulum_mass",
|
||||
"pendulum_com_x",
|
||||
"pendulum_com_y",
|
||||
"pendulum_com_z",
|
||||
"pendulum_ixx",
|
||||
"pendulum_iyy",
|
||||
"pendulum_izz",
|
||||
"pendulum_ixy",
|
||||
"pendulum_damping",
|
||||
"pendulum_frictionloss",
|
||||
"ctrl_limit"
|
||||
],
|
||||
"defaults": {
|
||||
"actuator_gear_pos": 0.371194,
|
||||
"actuator_gear_neg": 0.428143,
|
||||
"actuator_filter_tau": 0.022301,
|
||||
"motor_damping_pos": 0.001384,
|
||||
"motor_damping_neg": 0.005196,
|
||||
"motor_armature": 0.002753,
|
||||
"motor_frictionloss_pos": 0.036744,
|
||||
"motor_frictionloss_neg": 0.069082,
|
||||
"stribeck_friction_boost": 0.0,
|
||||
"stribeck_vel": 2.0,
|
||||
"motor_deadzone_pos": 0.14182,
|
||||
"motor_deadzone_neg": 0.031454,
|
||||
"action_bias": 0.0,
|
||||
"arm_mass": 0.0211,
|
||||
"arm_com_x": -0.0071,
|
||||
"arm_com_y": 0.00085,
|
||||
"arm_com_z": 0.00795,
|
||||
"pendulum_mass": 0.03937,
|
||||
"pendulum_com_x": 0.06025,
|
||||
"pendulum_com_y": -0.07602,
|
||||
"pendulum_com_z": -0.00346,
|
||||
"pendulum_ixx": 6.2e-05,
|
||||
"pendulum_iyy": 3.7e-05,
|
||||
"pendulum_izz": 7.83e-05,
|
||||
"pendulum_ixy": -6.93e-06,
|
||||
"pendulum_damping": 0.0001,
|
||||
"pendulum_frictionloss": 0.0001,
|
||||
"ctrl_limit": 0.588
|
||||
},
|
||||
"preprocess_vel": true,
|
||||
"timestamp": "2026-03-28T17:09:39.241413",
|
||||
"history_summary": {
|
||||
"first_cost": 13.490216059926947,
|
||||
"final_cost": 0.9250391788859942,
|
||||
"generations": 475
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
defaults:
|
||||
- env: cartpole
|
||||
- env: rotary_cartpole
|
||||
- runner: mujoco
|
||||
- training: ppo
|
||||
- _self_
|
||||
- _self_
|
||||
|
||||
11
configs/env/cartpole.yaml
vendored
11
configs/env/cartpole.yaml
vendored
@@ -1,11 +0,0 @@
|
||||
max_steps: 500
|
||||
angle_threshold: 0.418
|
||||
cart_limit: 2.4
|
||||
reward_alive: 1.0
|
||||
reward_pole_upright_scale: 1.0
|
||||
reward_action_penalty_scale: 0.01
|
||||
model_path: assets/cartpole/cartpole.urdf
|
||||
actuators:
|
||||
- joint: cart_joint
|
||||
gear: 10.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
28
configs/env/rotary_cartpole.yaml
vendored
Normal file
28
configs/env/rotary_cartpole.yaml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
max_steps: 1000
|
||||
robot_path: assets/rotary_cartpole
|
||||
reward_upright_scale: 1.0
|
||||
alive_bonus: 0.25 # per-step survival bonus (living must beat dying)
|
||||
balance_bonus: 2.0 # extra reward for upright AND still (beats spinning)
|
||||
balance_vel_scale: 0.5 # how fast the balance bonus decays with pendulum speed
|
||||
|
||||
# ── Regularisation penalties (prevent fast spinning) ─────────────────
|
||||
motor_vel_penalty: 0.01 # penalise high motor angular velocity
|
||||
motor_angle_penalty: 0.05 # penalise deviation from centre
|
||||
action_penalty: 0.05 # penalise large actions (energy cost)
|
||||
action_rate_penalty: 0.01 # penalise action changes (real-motor smoothness)
|
||||
|
||||
# ── Initial state randomisation ──────────────────────────────────────
|
||||
pendulum_init_range_deg: 180.0 # pendulum starts in [-180°, +180°]
|
||||
|
||||
# ── Software safety limit (env-level, always applied) ────────────────
|
||||
motor_angle_limit_deg: 90.0 # terminate episode if motor exceeds ±90°
|
||||
|
||||
# ── HPO search ranges ────────────────────────────────────────────────
|
||||
hpo:
|
||||
reward_upright_scale: {min: 0.5, max: 5.0}
|
||||
motor_vel_penalty: {min: 0.001, max: 0.1}
|
||||
motor_angle_penalty: {min: 0.01, max: 0.2}
|
||||
action_penalty: {min: 0.01, max: 0.2}
|
||||
action_rate_penalty: {min: 0.001, max: 0.1}
|
||||
pendulum_init_range_deg: {min: 30.0, max: 180.0}
|
||||
max_steps: {values: [500, 1000, 2000]}
|
||||
16
configs/runner/mjx.yaml
Normal file
16
configs/runner/mjx.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
num_envs: 1024 # MJX shines with many parallel envs
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10 # (obs, action) window for implicit adaptation
|
||||
|
||||
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||
# Full DR on GPU: latency + sensor noise + per-env dynamics scales
|
||||
# (friction/damping/torque) are all applied inside the JIT step.
|
||||
domain_rand:
|
||||
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
|
||||
damping_scale: [0.6, 1.6] # viscous-damping multiplier
|
||||
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation
|
||||
@@ -1,4 +1,16 @@
|
||||
num_envs: 16
|
||||
device: cpu
|
||||
dt: 0.02
|
||||
substeps: 2
|
||||
num_envs: 64
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10 # (obs, action) window for implicit adaptation
|
||||
|
||||
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||
# Noise/delay levels anchored to the real recordings (~50 Hz, ~0.5 rad/s
|
||||
# velocity noise, ≤1-step latency). Set domain_rand: {} to disable.
|
||||
domain_rand:
|
||||
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier
|
||||
damping_scale: [0.6, 1.6] # viscous-damping multiplier
|
||||
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation
|
||||
|
||||
15
configs/runner/mujoco_single.yaml
Normal file
15
configs/runner/mujoco_single.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
# Single-env MuJoCo runner — mimics real hardware timing.
|
||||
# dt × substeps = 0.002 × 10 = 0.02 s → 50 Hz control, same as serial runner.
|
||||
|
||||
num_envs: 1
|
||||
device: cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10
|
||||
|
||||
# Clean by default (deterministic eval). Confirming-experiment example —
|
||||
# re-eval an existing checkpoint in sim with a fixed 1-step action delay:
|
||||
# mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
|
||||
# checkpoint=runs/.../agent_XXXX.pt \
|
||||
# '++runner.domain_rand.action_delay_steps=[1,1]'
|
||||
domain_rand: {}
|
||||
11
configs/runner/serial.yaml
Normal file
11
configs/runner/serial.yaml
Normal file
@@ -0,0 +1,11 @@
|
||||
# Serial runner — communicates with real hardware over USB/serial.
|
||||
# Always single-env, CPU-only. Override port on CLI:
|
||||
# python scripts/train.py runner=serial runner.port=/dev/ttyUSB0
|
||||
|
||||
num_envs: 1
|
||||
device: cpu
|
||||
port: /dev/cu.usbserial-0001
|
||||
baud: 115200
|
||||
dt: 0.02 # control loop period (50 Hz, matches training)
|
||||
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
|
||||
history_length: 10 # must match training runner
|
||||
32
configs/sysid.yaml
Normal file
32
configs/sysid.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
# System identification defaults.
|
||||
# Override via CLI: python scripts/sysid.py optimize --max-generations 50
|
||||
#
|
||||
# These are NOT Hydra config groups — the sysid scripts use argparse.
|
||||
# This file serves as documentation and can be loaded by custom wrappers.
|
||||
|
||||
capture:
|
||||
port: /dev/cu.usbserial-0001
|
||||
baud: 115200
|
||||
duration: 20.0 # seconds
|
||||
amplitude: 150 # max PWM magnitude — must match firmware MAX_MOTOR_SPEED
|
||||
hold_min_ms: 50 # PRBS min hold time
|
||||
hold_max_ms: 300 # PRBS max hold time
|
||||
dt: 0.02 # sample period (50 Hz)
|
||||
|
||||
optimize:
|
||||
sigma0: 0.3 # CMA-ES initial step size (in [0,1] normalised space)
|
||||
population_size: 50 # candidates per generation
|
||||
max_generations: 1000 # total generations (~4000 evaluations)
|
||||
sim_dt: 0.002 # MuJoCo physics timestep
|
||||
substeps: 10 # physics substeps per control step (ctrl_dt = 0.02s)
|
||||
pos_weight: 1.0 # MSE weight for angle errors
|
||||
vel_weight: 0.1 # MSE weight for velocity errors
|
||||
window_duration: 0.5 # multiple-shooting window length (s); 0 = open-loop
|
||||
seed: 42
|
||||
|
||||
# Tunable hardware-realism params (added to ROTARY_CARTPOLE_PARAMS):
|
||||
# ctrl_limit — effective motor range → exported as ctrl_range in robot.yaml
|
||||
# motor_deadzone — L298N minimum |action| for torque → exported as deadzone in robot.yaml
|
||||
# Firmware sends raw (unfiltered) sensor data; EMA filtering is
|
||||
# handled on the Python side (env transforms) and is NOT part of
|
||||
# the sysid parameter search.
|
||||
@@ -1,7 +1,10 @@
|
||||
hidden_sizes: [128, 128]
|
||||
total_timesteps: 1000000
|
||||
rollout_steps: 1024
|
||||
learning_epochs: 4
|
||||
# PPO defaults — sized for the CPU MuJoCo runner (64 parallel envs).
|
||||
# 128 rollout steps × 64 envs ≈ 8K samples per update.
|
||||
|
||||
hidden_sizes: [256, 256]
|
||||
total_timesteps: 500000 # × 64 envs = 32M env steps
|
||||
rollout_steps: 128
|
||||
learning_epochs: 5
|
||||
mini_batches: 4
|
||||
discount_factor: 0.99
|
||||
gae_lambda: 0.95
|
||||
@@ -9,5 +12,31 @@ learning_rate: 0.0003
|
||||
clip_ratio: 0.2
|
||||
value_loss_scale: 0.5
|
||||
entropy_loss_scale: 0.01
|
||||
log_interval: 10
|
||||
clearml_project: RL-Framework
|
||||
kl_threshold: 0.01 # KL-adaptive LR; 0 = fixed learning rate
|
||||
log_interval: 1000
|
||||
checkpoint_interval: 50000
|
||||
|
||||
initial_log_std: -0.5
|
||||
min_log_std: -4.0
|
||||
max_log_std: 2.0
|
||||
|
||||
record_video_every: 10000
|
||||
|
||||
# History encoder output dim — the window size itself comes from
|
||||
# runner.history_length (single source of truth).
|
||||
embedding_dim: 32
|
||||
|
||||
# ClearML remote execution (GPU worker)
|
||||
remote: false
|
||||
|
||||
# ── HPO search ranges ────────────────────────────────────────────────
|
||||
# Read by scripts/hpo.py — ignored by TrainerConfig during training.
|
||||
hpo:
|
||||
learning_rate: {min: 0.00005, max: 0.001}
|
||||
clip_ratio: {min: 0.1, max: 0.3}
|
||||
discount_factor: {min: 0.98, max: 0.999}
|
||||
gae_lambda: {min: 0.9, max: 0.99}
|
||||
entropy_loss_scale: {min: 0.0001, max: 0.1}
|
||||
value_loss_scale: {min: 0.1, max: 1.0}
|
||||
learning_epochs: {min: 2, max: 8, type: int}
|
||||
mini_batches: {values: [2, 4, 8, 16]}
|
||||
|
||||
23
configs/training/ppo_mjx.yaml
Normal file
23
configs/training/ppo_mjx.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
# PPO sized for MJX (1024+ parallel envs on GPU).
|
||||
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||
#
|
||||
# Short rollouts × many envs is the GPU-PPO sweet spot:
|
||||
# 24 steps × 1024 envs ≈ 25K samples per update (~6K per mini-batch).
|
||||
# (The old rollout_steps=2048 inherited from the CPU config meant a
|
||||
# 2M-sample memory per update — GBs of VRAM and glacial updates.)
|
||||
|
||||
defaults:
|
||||
- ppo
|
||||
- _self_
|
||||
|
||||
rollout_steps: 24
|
||||
mini_batches: 4
|
||||
learning_epochs: 5
|
||||
learning_rate: 0.0003 # KL-adaptive scheduler handles the rest
|
||||
total_timesteps: 100000 # × 1024 envs ≈ 100M env steps
|
||||
log_interval: 100
|
||||
checkpoint_interval: 10000
|
||||
|
||||
record_video_every: 10000
|
||||
|
||||
remote: false
|
||||
29
configs/training/ppo_real.yaml
Normal file
29
configs/training/ppo_real.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
# PPO tuned for single-env real-time training on real hardware.
|
||||
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||
# ~50 Hz control × 1 env = ~50 timesteps/s.
|
||||
# 100k timesteps ≈ 33 minutes of wall-clock training.
|
||||
|
||||
defaults:
|
||||
- ppo
|
||||
- _self_
|
||||
|
||||
hidden_sizes: [256, 256]
|
||||
total_timesteps: 2000000
|
||||
learning_epochs: 10
|
||||
learning_rate: 0.0005 # conservative — can't undo real-world damage
|
||||
entropy_loss_scale: 0.01
|
||||
rollout_steps: 2048
|
||||
mini_batches: 8
|
||||
log_interval: 2048
|
||||
checkpoint_interval: 5000 # frequent saves — can't rewind real hardware
|
||||
initial_log_std: -0.5 # moderate initial exploration
|
||||
min_log_std: -4.0
|
||||
max_log_std: 2.0 # cap σ at 1.0
|
||||
|
||||
# Never run real-hardware training remotely
|
||||
remote: false
|
||||
|
||||
# Tighter HPO ranges for real hardware (override base ppo.yaml ranges)
|
||||
hpo:
|
||||
entropy_loss_scale: {min: 0.00005, max: 0.001}
|
||||
learning_rate: {min: 0.0003, max: 0.003}
|
||||
25
configs/training/ppo_single.yaml
Normal file
25
configs/training/ppo_single.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
# PPO tuned for single-env simulation — mimics real hardware training.
|
||||
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||
# Same 50 Hz control (runner=mujoco_single), 1 env, conservative hypers.
|
||||
# Sim runs ~100× faster than real time, so we can afford more timesteps.
|
||||
|
||||
defaults:
|
||||
- ppo
|
||||
- _self_
|
||||
|
||||
hidden_sizes: [256, 256]
|
||||
total_timesteps: 2000000
|
||||
learning_epochs: 10
|
||||
learning_rate: 0.0003
|
||||
entropy_loss_scale: 0.01
|
||||
rollout_steps: 2048
|
||||
mini_batches: 8
|
||||
log_interval: 2048
|
||||
checkpoint_interval: 10000
|
||||
initial_log_std: -0.5
|
||||
min_log_std: -4.0
|
||||
max_log_std: 2.0
|
||||
|
||||
record_video_every: 50000
|
||||
|
||||
remote: false
|
||||
@@ -1,8 +1,21 @@
|
||||
torch
|
||||
gymnasium
|
||||
gymnasium==1.2.3
|
||||
hydra-core
|
||||
omegaconf
|
||||
mujoco
|
||||
skrl[torch]
|
||||
mujoco==3.5.0
|
||||
mujoco-mjx==3.5.0
|
||||
jax[cuda12]==0.9.1 ; sys_platform == "linux"
|
||||
jax==0.9.1 ; sys_platform != "linux"
|
||||
skrl[torch]==1.4.3
|
||||
clearml
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
structlog
|
||||
pyyaml
|
||||
pyserial
|
||||
cmaes
|
||||
matplotlib
|
||||
smac>=2.0.0
|
||||
ConfigSpace
|
||||
hpbandster
|
||||
pytest
|
||||
403
scripts/eval.py
Normal file
403
scripts/eval.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""Evaluate a trained policy on real hardware (or in simulation).
|
||||
|
||||
Loads a checkpoint and runs the policy in a closed loop. For real
|
||||
hardware the serial runner talks to the ESP32; for sim it uses the
|
||||
MuJoCo runner. A digital-twin MuJoCo viewer mirrors the robot state
|
||||
in both modes.
|
||||
|
||||
Usage (real hardware):
|
||||
mjpython scripts/eval.py env=rotary_cartpole runner=serial \
|
||||
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
|
||||
|
||||
Usage (simulation):
|
||||
mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
|
||||
checkpoint=runs/26-03-12_00-16-43-308420_PPO/checkpoints/agent_1000000.pt
|
||||
|
||||
Controls:
|
||||
Space — pause / resume policy (motor stops while paused)
|
||||
R — reset environment
|
||||
Esc — quit
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.models.mlp import SharedMLP
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
|
||||
_reset_flag = [False]
|
||||
_paused = [False]
|
||||
_quit_flag = [False]
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
|
||||
"""Called by MuJoCo viewer on key press."""
|
||||
if keycode == 32: # GLFW_KEY_SPACE
|
||||
_paused[0] = not _paused[0]
|
||||
elif keycode == 82: # GLFW_KEY_R
|
||||
_reset_flag[0] = True
|
||||
elif keycode == 256: # GLFW_KEY_ESCAPE
|
||||
_quit_flag[0] = True
|
||||
|
||||
|
||||
# ── checkpoint loading ───────────────────────────────────────────────
|
||||
|
||||
def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
|
||||
"""Infer hidden layer sizes from a SharedMLP state dict."""
|
||||
sizes = []
|
||||
i = 0
|
||||
while f"net.{i}.weight" in state_dict:
|
||||
sizes.append(state_dict[f"net.{i}.weight"].shape[0])
|
||||
i += 2 # skip activation layers (ELU)
|
||||
return tuple(sizes)
|
||||
|
||||
|
||||
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
|
||||
"""Return the history encoder output dim, if present.
|
||||
|
||||
Lets eval reconstruct an embedding policy without knowing the training
|
||||
embedding_dim — read it straight from the saved weights.
|
||||
"""
|
||||
if "history_encoder.fc.weight" in state_dict:
|
||||
return state_dict["history_encoder.fc.weight"].shape[0]
|
||||
return None
|
||||
|
||||
|
||||
def load_policy(
|
||||
checkpoint_path: str,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
history_length: int = 0,
|
||||
raw_obs_dim: int = 0,
|
||||
) -> tuple[SharedMLP, RunningStandardScaler]:
|
||||
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
|
||||
|
||||
For DR + history-embedding policies (history_length > 0), the history
|
||||
encoder is reconstructed too — its output dim is read back from the
|
||||
saved weights.
|
||||
|
||||
Returns:
|
||||
(model, state_preprocessor) ready for inference.
|
||||
"""
|
||||
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
||||
|
||||
# Infer architecture from saved weights.
|
||||
hidden_sizes = _infer_hidden_sizes(ckpt["policy"])
|
||||
enc_out = _infer_encoder_out_dim(ckpt["policy"])
|
||||
|
||||
# Reconstruct model — pass through the encoder config so a DR+embedding
|
||||
# checkpoint rebuilds the history encoder with matching dimensions.
|
||||
model = SharedMLP(
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
device=device,
|
||||
hidden_sizes=hidden_sizes,
|
||||
history_length=history_length if enc_out else 0,
|
||||
raw_obs_dim=raw_obs_dim,
|
||||
embedding_dim=enc_out or 32,
|
||||
)
|
||||
model.load_state_dict(ckpt["policy"])
|
||||
model.eval()
|
||||
|
||||
# Reconstruct observation normalizer.
|
||||
state_preprocessor = RunningStandardScaler(size=observation_space, device=device)
|
||||
state_preprocessor.running_mean = ckpt["state_preprocessor"]["running_mean"].to(device)
|
||||
state_preprocessor.running_variance = ckpt["state_preprocessor"]["running_variance"].to(device)
|
||||
state_preprocessor.current_count = ckpt["state_preprocessor"]["current_count"]
|
||||
# Freeze the normalizer — don't update stats during eval.
|
||||
state_preprocessor.training = False
|
||||
|
||||
logger.info(
|
||||
"checkpoint_loaded",
|
||||
path=checkpoint_path,
|
||||
hidden_sizes=hidden_sizes,
|
||||
obs_mean=[round(x, 3) for x in state_preprocessor.running_mean.tolist()],
|
||||
obs_std=[round(x, 3) for x in state_preprocessor.running_variance.sqrt().tolist()],
|
||||
)
|
||||
return model, state_preprocessor
|
||||
|
||||
|
||||
# ── action arrow overlay ─────────────────────────────────────────────
|
||||
|
||||
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
||||
"""Draw an arrow showing applied torque direction."""
|
||||
if abs(action_val) < 0.01 or model.nu == 0:
|
||||
return
|
||||
|
||||
jnt_id = model.actuator_trnid[0, 0]
|
||||
body_id = model.jnt_bodyid[jnt_id]
|
||||
pos = data.xpos[body_id].copy()
|
||||
pos[2] += 0.02
|
||||
|
||||
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
|
||||
arrow_len = 0.08 * action_val
|
||||
direction = axis * np.sign(arrow_len)
|
||||
|
||||
z = direction / (np.linalg.norm(direction) + 1e-8)
|
||||
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
|
||||
x = np.cross(up, z)
|
||||
x /= np.linalg.norm(x) + 1e-8
|
||||
y = np.cross(z, x)
|
||||
mat = np.column_stack([x, y, z]).flatten()
|
||||
|
||||
rgba = np.array(
|
||||
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
|
||||
mujoco.mjv_initGeom(
|
||||
geom,
|
||||
type=mujoco.mjtGeom.mjGEOM_ARROW,
|
||||
size=np.array([0.008, 0.008, abs(arrow_len)]),
|
||||
pos=pos,
|
||||
mat=mat,
|
||||
rgba=rgba,
|
||||
)
|
||||
viewer.user_scn.ngeom += 1
|
||||
|
||||
|
||||
# ── main loops ───────────────────────────────────────────────────────
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco_single")
|
||||
|
||||
checkpoint_path = cfg.get("checkpoint", None)
|
||||
if checkpoint_path is None:
|
||||
logger.error("No checkpoint specified. Use: +checkpoint=path/to/agent.pt")
|
||||
sys.exit(1)
|
||||
|
||||
# Resolve relative paths against original working directory.
|
||||
checkpoint_path = str(Path(hydra.utils.get_original_cwd()) / checkpoint_path)
|
||||
if not Path(checkpoint_path).exists():
|
||||
logger.error("checkpoint_not_found", path=checkpoint_path)
|
||||
sys.exit(1)
|
||||
|
||||
if runner_name == "serial":
|
||||
_eval_serial(cfg, env_name, checkpoint_path)
|
||||
else:
|
||||
_eval_sim(cfg, env_name, checkpoint_path)
|
||||
|
||||
|
||||
def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
"""Evaluate policy in MuJoCo simulation with viewer."""
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
env = build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
runner_dict["num_envs"] = 1
|
||||
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
|
||||
|
||||
device = runner.device
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, runner.observation_space, runner.action_space, device,
|
||||
history_length=runner.config.history_length,
|
||||
raw_obs_dim=runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
mj_model = runner._model
|
||||
mj_data = runner._data[0]
|
||||
dt_ctrl = runner.config.dt * runner.config.substeps
|
||||
|
||||
with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
episode = 0
|
||||
episode_reward = 0.0
|
||||
|
||||
logger.info(
|
||||
"eval_started",
|
||||
env=env_name,
|
||||
mode="simulation",
|
||||
checkpoint=Path(checkpoint_path).name,
|
||||
controls="Space=pause, R=reset, Esc=quit",
|
||||
)
|
||||
|
||||
while viewer.is_running() and not _quit_flag[0]:
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
logger.info("reset", episode=episode)
|
||||
|
||||
if _paused[0]:
|
||||
viewer.sync()
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
|
||||
# Policy inference
|
||||
with torch.no_grad():
|
||||
normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
|
||||
action = model.act({"states": normalized_obs}, role="policy")[0]
|
||||
action = action.clamp(-1.0, 1.0)
|
||||
|
||||
obs, reward, terminated, truncated, info = runner.step(action)
|
||||
episode_reward += reward.item()
|
||||
step += 1
|
||||
|
||||
# Sync viewer
|
||||
mujoco.mj_forward(mj_model, mj_data)
|
||||
viewer.user_scn.ngeom = 0
|
||||
_add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
|
||||
viewer.sync()
|
||||
|
||||
if step % 50 == 0:
|
||||
joints = {mj_model.jnt(i).name: round(math.degrees(mj_data.qpos[i]), 1)
|
||||
for i in range(mj_model.njnt)}
|
||||
logger.debug(
|
||||
"step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action[0, 0].item(), 2),
|
||||
ep_reward=round(episode_reward, 1), **joints,
|
||||
)
|
||||
|
||||
if terminated.any() or truncated.any():
|
||||
logger.info(
|
||||
"episode_done", episode=episode, steps=step,
|
||||
total_reward=round(episode_reward, 2),
|
||||
reason="terminated" if terminated.any() else "truncated",
|
||||
)
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
|
||||
time.sleep(dt_ctrl)
|
||||
|
||||
runner.close()
|
||||
|
||||
|
||||
def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
"""Evaluate policy on real hardware via serial, with digital-twin viewer."""
|
||||
from src.runners.serial import SerialRunner, SerialRunnerConfig
|
||||
|
||||
env = build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_dict))
|
||||
|
||||
device = serial_runner.device
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
|
||||
history_length=serial_runner.config.history_length,
|
||||
raw_obs_dim=serial_runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
# Set up digital-twin MuJoCo model for visualization.
|
||||
serial_runner._ensure_viz_model()
|
||||
mj_model = serial_runner._viz_model
|
||||
mj_data = serial_runner._viz_data
|
||||
|
||||
with mujoco.viewer.launch_passive(mj_model, mj_data, key_callback=_key_callback) as viewer:
|
||||
obs, _ = serial_runner.reset()
|
||||
step = 0
|
||||
episode = 0
|
||||
episode_reward = 0.0
|
||||
|
||||
logger.info(
|
||||
"eval_started",
|
||||
env=env_name,
|
||||
mode="real hardware (serial)",
|
||||
port=serial_runner.config.port,
|
||||
checkpoint=Path(checkpoint_path).name,
|
||||
controls="Space=pause, R=reset, Esc=quit",
|
||||
)
|
||||
|
||||
while viewer.is_running() and not _quit_flag[0]:
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
serial_runner._send("M0")
|
||||
obs, _ = serial_runner.reset() # drives to center + settles
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
logger.info("reset", episode=episode)
|
||||
|
||||
if _paused[0]:
|
||||
serial_runner._send("M0") # safety: stop motor while paused
|
||||
serial_runner._sync_viz()
|
||||
viewer.sync()
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
|
||||
# Policy inference
|
||||
with torch.no_grad():
|
||||
normalized_obs = preprocessor(obs.unsqueeze(0) if obs.dim() == 1 else obs)
|
||||
action = model.act({"states": normalized_obs}, role="policy")[0]
|
||||
action = action.clamp(-1.0, 1.0)
|
||||
|
||||
obs, reward, terminated, truncated, info = serial_runner.step(action)
|
||||
episode_reward += reward.item()
|
||||
step += 1
|
||||
|
||||
# Sync digital twin with real sensor data.
|
||||
serial_runner._sync_viz()
|
||||
viewer.user_scn.ngeom = 0
|
||||
_add_action_arrow(viewer, mj_model, mj_data, action[0, 0].item())
|
||||
viewer.sync()
|
||||
|
||||
if step % 25 == 0:
|
||||
state = serial_runner._read_state()
|
||||
logger.debug(
|
||||
"step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action[0, 0].item(), 2),
|
||||
ep_reward=round(episode_reward, 1),
|
||||
motor_deg=round(math.degrees(state["motor_rad"]), 1),
|
||||
pend_deg=round(math.degrees(state["pend_rad"]), 1),
|
||||
)
|
||||
|
||||
# Check for safety / disconnection.
|
||||
if info.get("reboot_detected") or info.get("motor_limit_exceeded"):
|
||||
logger.error(
|
||||
"safety_stop",
|
||||
reboot=info.get("reboot_detected", False),
|
||||
motor_limit=info.get("motor_limit_exceeded", False),
|
||||
)
|
||||
serial_runner._send("M0")
|
||||
break
|
||||
|
||||
if terminated.any() or truncated.any():
|
||||
logger.info(
|
||||
"episode_done", episode=episode, steps=step,
|
||||
total_reward=round(episode_reward, 2),
|
||||
reason="terminated" if terminated.any() else "truncated",
|
||||
)
|
||||
# Auto-reset for next episode.
|
||||
obs, _ = serial_runner.reset()
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
|
||||
# Real-time pacing is handled by serial_runner.step() (dt sleep).
|
||||
|
||||
serial_runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
442
scripts/hpo.py
Normal file
442
scripts/hpo.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""Hyperparameter optimization for RL-Framework using ClearML + SMAC3.
|
||||
|
||||
Automatically creates a base training task (via Task.create), reads HPO
|
||||
search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
|
||||
and launches SMAC3 Successive Halving optimization.
|
||||
|
||||
Usage:
|
||||
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single
|
||||
|
||||
# With HPO-specific options:
|
||||
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single \\
|
||||
--queue gpu-queue --total-trials 100
|
||||
|
||||
# Or use an existing base task:
|
||||
python scripts/hpo.py --base-task-id <TASK_ID>
|
||||
|
||||
# Dry run (print search space only):
|
||||
python scripts/hpo.py env=rotary_cartpole --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from clearml.automation import (
|
||||
DiscreteParameterRange,
|
||||
HyperParameterOptimizer,
|
||||
UniformIntegerParameterRange,
|
||||
UniformParameterRange,
|
||||
)
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def _load_hydra_config(
|
||||
env: str, runner: str, training: str
|
||||
) -> dict:
|
||||
"""Load and merge Hydra configs to extract HPO ranges.
|
||||
|
||||
We read the YAML files directly (without running Hydra) so this script
|
||||
doesn't need @hydra.main — it's a ClearML optimizer, not a training job.
|
||||
"""
|
||||
configs_dir = Path(__file__).resolve().parent.parent / "configs"
|
||||
|
||||
# Load training config (handles defaults: [ppo] inheritance)
|
||||
training_path = configs_dir / "training" / f"{training}.yaml"
|
||||
training_cfg = OmegaConf.load(training_path)
|
||||
|
||||
# If the training config has defaults pointing to a base, load + merge
|
||||
if "defaults" in training_cfg:
|
||||
defaults = OmegaConf.to_container(training_cfg.defaults)
|
||||
base_cfg = OmegaConf.create({})
|
||||
for d in defaults:
|
||||
if isinstance(d, str):
|
||||
base_path = configs_dir / "training" / f"{d}.yaml"
|
||||
if base_path.exists():
|
||||
loaded = OmegaConf.load(base_path)
|
||||
base_cfg = OmegaConf.merge(base_cfg, loaded)
|
||||
# Remove defaults key and merge
|
||||
training_no_defaults = {
|
||||
k: v for k, v in OmegaConf.to_container(training_cfg).items()
|
||||
if k != "defaults"
|
||||
}
|
||||
training_cfg = OmegaConf.merge(base_cfg, OmegaConf.create(training_no_defaults))
|
||||
|
||||
# Load env config
|
||||
env_path = configs_dir / "env" / f"{env}.yaml"
|
||||
env_cfg = OmegaConf.load(env_path) if env_path.exists() else OmegaConf.create({})
|
||||
|
||||
return {
|
||||
"training": OmegaConf.to_container(training_cfg, resolve=True),
|
||||
"env": OmegaConf.to_container(env_cfg, resolve=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_hyper_parameters(config: dict) -> list:
|
||||
"""Build ClearML parameter ranges from hpo: blocks in config.
|
||||
|
||||
Reads training.hpo and env.hpo dicts and creates appropriate
|
||||
ClearML parameter range objects.
|
||||
|
||||
Each hpo entry can have:
|
||||
{min, max} → UniformParameterRange (float)
|
||||
{min, max, type: int} → UniformIntegerParameterRange
|
||||
{min, max, log: true} → UniformParameterRange with log scale
|
||||
{values: [...]} → DiscreteParameterRange
|
||||
"""
|
||||
params = []
|
||||
|
||||
for section in ("training", "env"):
|
||||
hpo_ranges = config.get(section, {}).get("hpo", {})
|
||||
if not hpo_ranges:
|
||||
continue
|
||||
|
||||
for param_name, spec in hpo_ranges.items():
|
||||
hydra_key = f"Hydra/{section}.{param_name}"
|
||||
|
||||
if "values" in spec:
|
||||
params.append(
|
||||
DiscreteParameterRange(hydra_key, values=spec["values"])
|
||||
)
|
||||
elif "min" in spec and "max" in spec:
|
||||
if spec.get("type") == "int":
|
||||
params.append(
|
||||
UniformIntegerParameterRange(
|
||||
hydra_key,
|
||||
min_value=int(spec["min"]),
|
||||
max_value=int(spec["max"]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
step = spec.get("step", None)
|
||||
params.append(
|
||||
UniformParameterRange(
|
||||
hydra_key,
|
||||
min_value=float(spec["min"]),
|
||||
max_value=float(spec["max"]),
|
||||
step_size=step,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
|
||||
"""Flatten a nested dict into dot-separated keys.
|
||||
|
||||
Example: {"a": {"b": 1}} → {"a.b": 1}
|
||||
"""
|
||||
items = {}
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.update(_flatten_dict(v, new_key, sep=sep))
|
||||
else:
|
||||
items[new_key] = v
|
||||
return items
|
||||
|
||||
|
||||
def _create_base_task(
|
||||
env: str, runner: str, training: str, queue: str
|
||||
) -> str:
|
||||
"""Create a base ClearML task without executing it.
|
||||
|
||||
Uses Task.create() to register a task pointing at scripts/train.py
|
||||
with the correct Hydra overrides. The HPO optimizer will clone this.
|
||||
The full resolved OmegaConf config is attached as Hydra/* parameters
|
||||
so cloned trial tasks inherit the complete configuration.
|
||||
"""
|
||||
script_path = str(Path(__file__).resolve().parent / "train.py")
|
||||
project_root = str(Path(__file__).resolve().parent.parent)
|
||||
|
||||
base_task = Task.create(
|
||||
project_name="RL-Framework",
|
||||
task_name=f"{env}-{runner}-{training} (HPO base)",
|
||||
task_type=Task.TaskTypes.training,
|
||||
script=script_path,
|
||||
working_directory=project_root,
|
||||
argparse_args=[
|
||||
f"env={env}",
|
||||
f"runner={runner}",
|
||||
f"training={training}",
|
||||
],
|
||||
add_task_init_call=False,
|
||||
)
|
||||
|
||||
# ── Attach full resolved OmegaConf config ─────────────────────
|
||||
# ClearML's Hydra binding normally does this when the script runs,
|
||||
# but Task.create() never executes Hydra. We replicate the binding
|
||||
# manually: config group choices + all resolved values.
|
||||
base_task.set_parameter("Hydra/env", env)
|
||||
base_task.set_parameter("Hydra/runner", runner)
|
||||
base_task.set_parameter("Hydra/training", training)
|
||||
|
||||
# Load and resolve the full config for each group
|
||||
configs_dir = Path(__file__).resolve().parent.parent / "configs"
|
||||
for section, name in [("training", training), ("env", env), ("runner", runner)]:
|
||||
cfg_path = configs_dir / section / f"{name}.yaml"
|
||||
if not cfg_path.exists():
|
||||
continue
|
||||
cfg = OmegaConf.load(cfg_path)
|
||||
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
|
||||
if "defaults" in cfg:
|
||||
defaults = OmegaConf.to_container(cfg.defaults)
|
||||
base_cfg = OmegaConf.create({})
|
||||
for d in defaults:
|
||||
if isinstance(d, str):
|
||||
base_path = configs_dir / section / f"{d}.yaml"
|
||||
if base_path.exists():
|
||||
loaded = OmegaConf.load(base_path)
|
||||
base_cfg = OmegaConf.merge(base_cfg, loaded)
|
||||
cfg_no_defaults = {
|
||||
k: v for k, v in OmegaConf.to_container(cfg).items()
|
||||
if k != "defaults"
|
||||
}
|
||||
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
|
||||
|
||||
resolved = OmegaConf.to_container(cfg, resolve=True)
|
||||
# Remove hpo metadata — not a real config value
|
||||
resolved.pop("hpo", None)
|
||||
flat = _flatten_dict(resolved)
|
||||
for key, value in flat.items():
|
||||
base_task.set_parameter(f"Hydra/{section}.{key}", value)
|
||||
|
||||
# Set docker config
|
||||
base_task.set_base_docker(
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
|
||||
),
|
||||
docker_arguments=[
|
||||
"-e", "MUJOCO_GL=osmesa",
|
||||
],
|
||||
)
|
||||
|
||||
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||
base_task.set_packages(str(req_file))
|
||||
|
||||
task_id = base_task.id
|
||||
logger.info("base_task_created", task_id=task_id, task_name=base_task.name)
|
||||
return task_id
|
||||
|
||||
|
||||
def _parse_overrides(argv: list[str]) -> dict[str, str]:
|
||||
"""Parse Hydra-style key=value overrides from argv.
|
||||
|
||||
Returns a dict of parsed key-value pairs. Unknown args (--flags)
|
||||
are left in argv for argparse to handle.
|
||||
"""
|
||||
overrides = {}
|
||||
remaining = []
|
||||
for arg in argv:
|
||||
if "=" in arg and not arg.startswith("-"):
|
||||
key, value = arg.split("=", 1)
|
||||
overrides[key] = value
|
||||
else:
|
||||
remaining.append(arg)
|
||||
argv.clear()
|
||||
argv.extend(remaining)
|
||||
return overrides
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# First pass: extract Hydra-style key=value overrides from sys.argv
|
||||
raw_args = sys.argv[1:]
|
||||
overrides = _parse_overrides(raw_args)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Hyperparameter optimization for RL-Framework",
|
||||
usage="%(prog)s env=<ENV> runner=<RUNNER> training=<TRAINING> [options]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-task-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Existing ClearML task ID to use as base (skip auto-creation)",
|
||||
)
|
||||
parser.add_argument("--queue", type=str, default="gpu-queue")
|
||||
parser.add_argument(
|
||||
"--max-concurrent", type=int, default=2,
|
||||
help="Maximum concurrent trial tasks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--total-trials", type=int, default=200,
|
||||
help="Total HPO trial budget",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-budget", type=int, default=50_000,
|
||||
help="Minimum budget (total_timesteps) per trial",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-budget", type=int, default=500_000,
|
||||
help="Maximum budget (total_timesteps) for promoted trials",
|
||||
)
|
||||
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
|
||||
parser.add_argument(
|
||||
"--max-consecutive-failures", type=int, default=3,
|
||||
help="Abort HPO after N consecutive trial failures (0 = never abort)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time-limit-hours", type=float, default=72,
|
||||
help="Total wall-clock time limit in hours",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--objective-metric", type=str, default="Reward / Total reward (mean)",
|
||||
help="ClearML scalar metric title to optimize",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--objective-series", type=str, default=None,
|
||||
help="ClearML scalar metric series (default: same as title)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--maximize", action="store_true", default=True,
|
||||
help="Maximize the objective (default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minimize", action="store_true", default=False,
|
||||
help="Minimize the objective",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run", action="store_true",
|
||||
help="Print search space and exit without running",
|
||||
)
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
# Resolve env/runner/training from Hydra-style overrides (same as train.py)
|
||||
env = overrides.get("env", "rotary_cartpole")
|
||||
runner = overrides.get("runner", "mujoco_single")
|
||||
training = overrides.get("training", "ppo_single")
|
||||
|
||||
objective_sign = "min" if args.minimize else "max"
|
||||
|
||||
# ── Load config and build search space ────────────────────────
|
||||
config = _load_hydra_config(env, runner, training)
|
||||
hyper_parameters = _build_hyper_parameters(config)
|
||||
|
||||
if not hyper_parameters:
|
||||
logger.error(
|
||||
"no_hpo_ranges_found",
|
||||
hint="Add 'hpo:' blocks to your training and/or env YAML configs",
|
||||
)
|
||||
return
|
||||
|
||||
if args.dry_run:
|
||||
print(f"\nSearch space ({len(hyper_parameters)} parameters):")
|
||||
for p in hyper_parameters:
|
||||
print(f" {p.name}: {p}")
|
||||
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
|
||||
return
|
||||
|
||||
# ── Initialize ClearML HPO task ───────────────────────────────
|
||||
Task.ignore_requirements("torch")
|
||||
task = Task.init(
|
||||
project_name="RL-Framework",
|
||||
task_name=f"HPO {env}-{runner}-{training}",
|
||||
task_type=Task.TaskTypes.optimizer,
|
||||
reuse_last_task_id=False,
|
||||
)
|
||||
task.set_base_docker(
|
||||
docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||
docker_arguments=[
|
||||
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
|
||||
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
|
||||
"-e", "CLEARML_AGENT_FORCE_SYSTEM_SITE_PACKAGES=1",
|
||||
],
|
||||
)
|
||||
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# ── Create or reuse base task ─────────────────────────────────
|
||||
# Store the base_task_id on the HPO task so that when the services
|
||||
# worker re-runs this script it reuses the same base task instead
|
||||
# of creating a duplicate.
|
||||
if args.base_task_id:
|
||||
base_task_id = args.base_task_id
|
||||
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||
else:
|
||||
existing = task.get_parameter("General/base_task_id")
|
||||
if existing:
|
||||
base_task_id = existing
|
||||
logger.info("reusing_base_task_from_param", task_id=base_task_id)
|
||||
else:
|
||||
base_task_id = _create_base_task(
|
||||
env, runner, training, args.queue
|
||||
)
|
||||
task.set_parameter("General/base_task_id", base_task_id)
|
||||
|
||||
# ── Build objective metric ────────────────────────────────────
|
||||
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
|
||||
objective_title = args.objective_metric
|
||||
objective_series = args.objective_series or objective_title
|
||||
|
||||
# ── Launch optimizer ──────────────────────────────────────────
|
||||
from src.hpo.smac3 import OptimizerSMAC
|
||||
|
||||
optimizer = HyperParameterOptimizer(
|
||||
base_task_id=base_task_id,
|
||||
hyper_parameters=hyper_parameters,
|
||||
objective_metric_title=objective_title,
|
||||
objective_metric_series=objective_series,
|
||||
objective_metric_sign=objective_sign,
|
||||
optimizer_class=OptimizerSMAC,
|
||||
execution_queue=args.queue,
|
||||
max_number_of_concurrent_tasks=args.max_concurrent,
|
||||
total_max_jobs=args.total_trials,
|
||||
min_iteration_per_job=args.min_budget,
|
||||
max_iteration_per_job=args.max_budget,
|
||||
pool_period_min=1,
|
||||
time_limit_per_job=240, # 4 hours per trial max
|
||||
eta=args.eta,
|
||||
budget_param_name="Hydra/training.total_timesteps",
|
||||
max_consecutive_failures=args.max_consecutive_failures,
|
||||
)
|
||||
|
||||
# Send this HPO controller to a remote services worker
|
||||
task.execute_remotely(queue_name="services", exit_process=True)
|
||||
|
||||
# Reporting and time limits
|
||||
optimizer.set_report_period(1)
|
||||
optimizer.set_time_limit(in_minutes=int(args.time_limit_hours * 60))
|
||||
|
||||
# Start and wait
|
||||
optimizer.start()
|
||||
optimizer.wait()
|
||||
|
||||
# Get top experiments
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
top_exp = optimizer.get_top_experiments(top_k=10)
|
||||
logger.info("top_experiments_retrieved", count=len(top_exp))
|
||||
for i, t in enumerate(top_exp):
|
||||
logger.info("top_experiment", rank=i + 1, task_id=t.id, name=t.name)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("retry_get_top_experiments", attempt=attempt + 1, error=str(e))
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(5.0 * (2 ** attempt))
|
||||
else:
|
||||
logger.error("could_not_retrieve_top_experiments")
|
||||
|
||||
optimizer.stop()
|
||||
logger.info("hpo_complete")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
57
scripts/sysid.py
Normal file
57
scripts/sysid.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Unified CLI for system identification tools.
|
||||
|
||||
Usage:
|
||||
python scripts/sysid.py capture --robot-path assets/rotary_cartpole --duration 20
|
||||
python scripts/sysid.py optimize --robot-path assets/rotary_cartpole --recording <file>.npz
|
||||
python scripts/sysid.py visualize --recording <file>.npz
|
||||
python scripts/sysid.py export --robot-path assets/rotary_cartpole --result <result>.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
|
||||
print(
|
||||
"Usage: python scripts/sysid.py <command> [options]\n"
|
||||
"\n"
|
||||
"Commands:\n"
|
||||
" capture Record real robot trajectory under PRBS excitation\n"
|
||||
" optimize Run CMA-ES parameter optimization\n"
|
||||
" visualize Plot real vs simulated trajectories\n"
|
||||
" export Write tuned URDF + robot.yaml files\n"
|
||||
"\n"
|
||||
"Run 'python scripts/sysid.py <command> --help' for command-specific options."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
command = sys.argv[1]
|
||||
# Remove the subcommand from argv so the module's argparse works normally
|
||||
sys.argv = [f"sysid {command}"] + sys.argv[2:]
|
||||
|
||||
if command == "capture":
|
||||
from src.sysid.capture import main as cmd_main
|
||||
elif command == "optimize":
|
||||
from src.sysid.optimize import main as cmd_main
|
||||
elif command == "visualize":
|
||||
from src.sysid.visualize import main as cmd_main
|
||||
elif command == "export":
|
||||
from src.sysid.export import main as cmd_main
|
||||
else:
|
||||
print(f"Unknown command: {command}")
|
||||
print("Available commands: capture, optimize, visualize, export")
|
||||
sys.exit(1)
|
||||
|
||||
cmd_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
130
scripts/train.py
Normal file
130
scripts/train.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
# Ensure project root is on sys.path so `src.*` imports work
|
||||
# regardless of which directory the script is invoked from.
|
||||
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
# Headless rendering on Linux servers (must be set before mujoco import).
|
||||
# EGL renders on the GPU directly (right for NVIDIA nodes) and avoids the
|
||||
# brittle OSMesa/PyOpenGL stack. Forced (not setdefault) so a stale
|
||||
# `-e MUJOCO_GL=osmesa` baked into a remote task can't override it.
|
||||
if sys.platform == "linux":
|
||||
os.environ["MUJOCO_GL"] = "egl"
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.registry import build_env
|
||||
from src.core.runner import BaseRunner
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
|
||||
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
|
||||
|
||||
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
|
||||
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
||||
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
|
||||
}
|
||||
|
||||
|
||||
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
|
||||
"""Instantiate the right runner from the Hydra config-group name."""
|
||||
if runner_name not in RUNNER_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
|
||||
)
|
||||
module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
|
||||
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
runner_cls = getattr(mod, cls_name)
|
||||
config_cls = getattr(mod, cfg_cls_name)
|
||||
|
||||
runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||
return runner_cls(env=env, config=runner_config)
|
||||
|
||||
|
||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
"""Initialize ClearML task with project structure and tags."""
|
||||
Task.ignore_requirements("torch")
|
||||
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco")
|
||||
training_name = choices.get("training", "ppo")
|
||||
|
||||
project = "RL-Framework"
|
||||
task_name = f"{env_name}-{runner_name}-{training_name}"
|
||||
tags = [env_name, runner_name, training_name]
|
||||
|
||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||
task.set_base_docker(
|
||||
"git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libegl1 libgl1 libglfw3 libosmesa6 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0"
|
||||
),
|
||||
docker_arguments=[
|
||||
"-e", "MUJOCO_GL=egl",
|
||||
],
|
||||
)
|
||||
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# Execute remotely if requested and running locally
|
||||
if remote and task.running_locally():
|
||||
logger.info("executing_task_remotely", queue="gpu-queue")
|
||||
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
|
||||
# ClearML init — must happen before heavy work so remote execution
|
||||
# can take over early. The remote worker re-runs the full script;
|
||||
# execute_remotely() is a no-op on the worker side.
|
||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||
remote = training_dict.pop("remote", False)
|
||||
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
||||
task = _init_clearml(choices, remote=remote)
|
||||
|
||||
# Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
|
||||
# resume_from_task_id or any future additions)
|
||||
import dataclasses as _dc
|
||||
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
||||
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
||||
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
env = build_env(env_name, cfg)
|
||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
try:
|
||||
trainer.train()
|
||||
finally:
|
||||
trainer.close()
|
||||
task.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
255
scripts/viz.py
Normal file
255
scripts/viz.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||
|
||||
Usage (simulation):
|
||||
mjpython scripts/viz.py env=rotary_cartpole
|
||||
mjpython scripts/viz.py env=rotary_cartpole +com=true
|
||||
|
||||
Usage (real hardware — digital twin):
|
||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
||||
|
||||
Controls:
|
||||
Left/Right arrows — apply torque to first actuator
|
||||
R — reset environment
|
||||
Esc / close window — quit
|
||||
"""
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
|
||||
_action_val = [0.0] # mutable container shared with callback
|
||||
_action_time = [0.0] # timestamp of last key press
|
||||
_reset_flag = [False]
|
||||
_ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
|
||||
"""Called by MuJoCo on key press & repeat (not release)."""
|
||||
if keycode == 263: # GLFW_KEY_LEFT
|
||||
_action_val[0] = -1.0
|
||||
_action_time[0] = time.time()
|
||||
elif keycode == 262: # GLFW_KEY_RIGHT
|
||||
_action_val[0] = 1.0
|
||||
_action_time[0] = time.time()
|
||||
elif keycode == 82: # GLFW_KEY_R
|
||||
_reset_flag[0] = True
|
||||
|
||||
|
||||
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
||||
"""Draw an arrow on the motor joint showing applied torque direction."""
|
||||
if abs(action_val) < 0.01 or model.nu == 0:
|
||||
return
|
||||
|
||||
# Get the body that the first actuator's joint belongs to
|
||||
jnt_id = model.actuator_trnid[0, 0]
|
||||
body_id = model.jnt_bodyid[jnt_id]
|
||||
|
||||
# Arrow origin: body position
|
||||
pos = data.xpos[body_id].copy()
|
||||
pos[2] += 0.02 # lift slightly above the body
|
||||
|
||||
# Arrow direction: along joint axis in world frame, scaled by action
|
||||
axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
|
||||
arrow_len = 0.08 * action_val
|
||||
direction = axis * np.sign(arrow_len)
|
||||
|
||||
# Build rotation matrix: arrow rendered along local z-axis
|
||||
z = direction / (np.linalg.norm(direction) + 1e-8)
|
||||
up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
|
||||
x = np.cross(up, z)
|
||||
x /= np.linalg.norm(x) + 1e-8
|
||||
y = np.cross(z, x)
|
||||
mat = np.column_stack([x, y, z]).flatten()
|
||||
|
||||
# Color: green = positive, red = negative
|
||||
rgba = np.array(
|
||||
[0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
|
||||
mujoco.mjv_initGeom(
|
||||
geom,
|
||||
type=mujoco.mjtGeom.mjGEOM_ARROW,
|
||||
size=np.array([0.008, 0.008, abs(arrow_len)]),
|
||||
pos=pos,
|
||||
mat=mat,
|
||||
rgba=rgba,
|
||||
)
|
||||
viewer.user_scn.ngeom += 1
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco")
|
||||
|
||||
if runner_name == "serial":
|
||||
_main_serial(cfg, env_name)
|
||||
else:
|
||||
_main_sim(cfg, env_name)
|
||||
|
||||
|
||||
def _main_sim(cfg: DictConfig, env_name: str) -> None:
|
||||
"""Simulation visualization — step MuJoCo physics with keyboard control."""
|
||||
|
||||
# Build env + runner (single env for viz)
|
||||
env = build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
runner_dict["num_envs"] = 1
|
||||
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
|
||||
|
||||
model = runner._model
|
||||
data = runner._data[0]
|
||||
|
||||
# Control period
|
||||
dt_ctrl = runner.config.dt * runner.config.substeps
|
||||
|
||||
# Launch viewer
|
||||
with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
|
||||
# Show CoM / inertia if requested via Hydra override: viz.py +com=true
|
||||
show_com = cfg.get("com", False)
|
||||
if show_com:
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
|
||||
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
|
||||
logger.info("viewer_started", env=env_name,
|
||||
controls="Left/Right arrows = torque, R = reset")
|
||||
|
||||
while viewer.is_running():
|
||||
# Read action from callback (expires after _ACTION_HOLD_S)
|
||||
if time.time() - _action_time[0] < _ACTION_HOLD_S:
|
||||
action_val = _action_val[0]
|
||||
else:
|
||||
action_val = 0.0
|
||||
|
||||
# Reset on R press
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
logger.info("reset")
|
||||
|
||||
# Step through runner
|
||||
action = torch.tensor([[action_val]])
|
||||
obs, reward, terminated, truncated, info = runner.step(action)
|
||||
|
||||
# Sync viewer with action arrow overlay
|
||||
mujoco.mj_forward(model, data)
|
||||
viewer.user_scn.ngeom = 0 # clear previous frame's overlays
|
||||
_add_action_arrow(viewer, model, data, action_val)
|
||||
viewer.sync()
|
||||
|
||||
# Print state
|
||||
if step % 25 == 0:
|
||||
joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
|
||||
for i in range(model.njnt)}
|
||||
logger.debug("step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action_val, 1), **joints)
|
||||
|
||||
# Real-time pacing
|
||||
time.sleep(dt_ctrl)
|
||||
step += 1
|
||||
|
||||
runner.close()
|
||||
|
||||
|
||||
def _main_serial(cfg: DictConfig, env_name: str) -> None:
|
||||
"""Digital-twin visualization — mirror real hardware in MuJoCo viewer.
|
||||
|
||||
The MuJoCo model is loaded for rendering only. Joint positions are
|
||||
read from the ESP32 over serial and applied to the model each frame.
|
||||
Keyboard arrows send motor commands to the real robot.
|
||||
"""
|
||||
from src.runners.serial import SerialRunner, SerialRunnerConfig
|
||||
|
||||
env = build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
serial_runner = SerialRunner(
|
||||
env=env, config=SerialRunnerConfig(**runner_dict)
|
||||
)
|
||||
|
||||
# Load MuJoCo model for visualisation (same URDF the sim uses).
|
||||
serial_runner._ensure_viz_model()
|
||||
model = serial_runner._viz_model
|
||||
data = serial_runner._viz_data
|
||||
|
||||
with mujoco.viewer.launch_passive(
|
||||
model, data, key_callback=_key_callback
|
||||
) as viewer:
|
||||
# Show CoM / inertia if requested.
|
||||
show_com = cfg.get("com", False)
|
||||
if show_com:
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
|
||||
|
||||
logger.info(
|
||||
"viewer_started",
|
||||
env=env_name,
|
||||
mode="serial (digital twin)",
|
||||
port=serial_runner.config.port,
|
||||
controls="Left/Right arrows = motor command, R = reset",
|
||||
)
|
||||
|
||||
while viewer.is_running():
|
||||
# Read action from keyboard callback.
|
||||
if time.time() - _action_time[0] < _ACTION_HOLD_S:
|
||||
action_val = _action_val[0]
|
||||
else:
|
||||
action_val = 0.0
|
||||
|
||||
# Reset on R press.
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
serial_runner._send("M0")
|
||||
serial_runner._drive_to_center()
|
||||
serial_runner._wait_for_settle()
|
||||
logger.info("reset (drive-to-center + settle)")
|
||||
|
||||
# Send motor command to real hardware (same PWM scaling as
|
||||
# the policy path: ctrl_range-limited).
|
||||
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
|
||||
serial_runner._send(f"M{motor_speed}")
|
||||
|
||||
# Sync MuJoCo model with real sensor data.
|
||||
serial_runner._sync_viz()
|
||||
|
||||
# Render overlays and sync viewer.
|
||||
viewer.user_scn.ngeom = 0
|
||||
_add_action_arrow(viewer, model, data, action_val)
|
||||
viewer.sync()
|
||||
|
||||
# Real-time pacing (~50 Hz, matches serial dt).
|
||||
time.sleep(serial_runner.config.dt)
|
||||
|
||||
serial_runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,33 +1,25 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
from typing import TypeVar, Generic, Any
|
||||
from gymnasium import spaces
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import pathlib
|
||||
from gymnasium import spaces
|
||||
|
||||
from src.core.robot import RobotConfig, load_robot_config
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ActuatorConfig:
|
||||
"""Actuator definition — maps a joint to a motor with gear ratio and control limits.
|
||||
Kept in the env config (not runner config) because actuators define what the robot
|
||||
can do, which determines action space — a task-level concept.
|
||||
This mirrors Isaac Lab's pattern of separating actuator config from the robot file."""
|
||||
joint: str = ""
|
||||
gear: float = 1.0
|
||||
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseEnvConfig:
|
||||
max_steps: int = 1000
|
||||
model_path: pathlib.Path | None = None
|
||||
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
|
||||
robot_path: str = "" # directory containing robot.yaml + URDF
|
||||
|
||||
class BaseEnv(abc.ABC, Generic[T]):
|
||||
def __init__(self, config: BaseEnvConfig):
|
||||
self.config = config
|
||||
self.robot: RobotConfig = load_robot_config(config.robot_path)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -48,7 +40,9 @@ class BaseEnv(abc.ABC, Generic[T]):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
|
||||
def compute_rewards(
|
||||
self, state: Any, actions: torch.Tensor, prev_actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -57,3 +51,26 @@ class BaseEnv(abc.ABC, Generic[T]):
|
||||
|
||||
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
|
||||
return step_counts >= self.config.max_steps
|
||||
|
||||
def initial_state_ranges(
|
||||
self, nq: int, nv: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Per-DOF uniform ranges for initial-state randomization.
|
||||
|
||||
Returns (qpos_lo, qpos_hi, qvel_lo, qvel_hi) — offsets added to the
|
||||
model's default state on every reset. All runners (CPU MuJoCo and
|
||||
MJX) sample from these, so initial-state distributions stay
|
||||
identical across backends. Default: small ±0.05 perturbation.
|
||||
"""
|
||||
return (
|
||||
np.full(nq, -0.05), np.full(nq, 0.05),
|
||||
np.full(nv, -0.05), np.full(nv, 0.05),
|
||||
)
|
||||
|
||||
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
|
||||
"""Check whether the physical robot has settled enough to start an episode.
|
||||
|
||||
Used by the SerialRunner after driving to center and waiting for the
|
||||
pendulum. Default: always ready (sim doesn't need settling).
|
||||
"""
|
||||
return True
|
||||
|
||||
22
src/core/registry.py
Normal file
22
src/core/registry.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Shared env registry and builder used by train.py and viz.py."""
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
|
||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||
}
|
||||
|
||||
|
||||
def build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
"""Instantiate the right env + config from the Hydra config-group name."""
|
||||
if env_name not in ENV_REGISTRY:
|
||||
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||
|
||||
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
env_dict.pop("hpo", None) # HPO range metadata — not an env config field
|
||||
return env_cls(config_cls(**env_dict))
|
||||
242
src/core/robot.py
Normal file
242
src/core/robot.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Robot hardware configuration — loaded from robot.yaml next to the URDF.
|
||||
|
||||
Separates robot hardware (actuators, joint tuning) from task config
|
||||
(rewards, episode length) and from the URDF (clean CAD export).
|
||||
|
||||
Usage:
|
||||
robot = load_robot_config(Path("assets/rotary_cartpole"))
|
||||
# robot.urdf_path → resolved absolute path to the URDF
|
||||
# robot.actuators → list of ActuatorConfig
|
||||
# robot.joints → dict of per-joint overrides
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def _as_pair(val) -> tuple[float, float]:
|
||||
"""Convert scalar or [pos, neg] list to (pos, neg) tuple."""
|
||||
if isinstance(val, (list, tuple)) and len(val) == 2:
|
||||
return (float(val[0]), float(val[1]))
|
||||
return (float(val), float(val))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ActuatorConfig:
|
||||
"""Motor/actuator attached to a joint.
|
||||
|
||||
Asymmetric fields use (positive_dir, negative_dir) tuples.
|
||||
A scalar in YAML is expanded to a symmetric pair.
|
||||
|
||||
type:
|
||||
motor — direct torque control (ctrl = normalised torque)
|
||||
position — PD position servo (ctrl = target angle, needs kp)
|
||||
velocity — P velocity servo (ctrl = target velocity, needs kp)
|
||||
"""
|
||||
joint: str = ""
|
||||
type: str = "motor"
|
||||
gear: tuple[float, float] = (1.0, 1.0) # torque constant (pos, neg)
|
||||
ctrl_range: tuple[float, float] = (-1.0, 1.0) # (lower, upper) control bounds
|
||||
deadzone: tuple[float, float] = (0.0, 0.0) # min |ctrl| for torque (pos, neg)
|
||||
damping: tuple[float, float] = (0.0, 0.0) # viscous damping (pos, neg)
|
||||
frictionloss: tuple[float, float] = (0.0, 0.0) # Coulomb friction (pos, neg)
|
||||
kp: float = 0.0 # proportional gain (position / velocity actuators)
|
||||
kv: float = 0.0 # derivative gain (position actuators)
|
||||
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
|
||||
viscous_quadratic: float = 0.0 # velocity² drag coefficient
|
||||
back_emf_gain: float = 0.0 # back-EMF torque reduction
|
||||
stribeck_friction_boost: float = 0.0 # extra static friction at low speed (N·m)
|
||||
stribeck_vel: float = 2.0 # Stribeck decay velocity (rad/s)
|
||||
action_bias: float = 0.0 # additive ctrl bias (driver asymmetry)
|
||||
|
||||
@property
|
||||
def gear_avg(self) -> float:
|
||||
return (self.gear[0] + self.gear[1]) / 2.0
|
||||
|
||||
@property
|
||||
def has_motor_model(self) -> bool:
|
||||
"""True if this actuator needs the runtime motor model."""
|
||||
return (
|
||||
self.gear[0] != self.gear[1]
|
||||
or self.deadzone != (0.0, 0.0)
|
||||
or self.damping != (0.0, 0.0)
|
||||
or self.frictionloss != (0.0, 0.0)
|
||||
or self.viscous_quadratic > 0
|
||||
or self.back_emf_gain > 0
|
||||
or self.stribeck_friction_boost > 0
|
||||
or self.action_bias != 0.0
|
||||
)
|
||||
|
||||
def transform_ctrl(self, ctrl: float) -> float:
|
||||
"""Clip to ctrl_range, then apply bias, deadzone and gear compensation.
|
||||
|
||||
Must stay in lock-step with the vectorised JAX version in
|
||||
``src/runners/mjx.py`` (step_fn) — sysid fits parameters against
|
||||
THIS function, so any drift breaks the identified model.
|
||||
"""
|
||||
# Clip to ctrl_range first (mirrors firmware PWM saturation).
|
||||
ctrl = max(self.ctrl_range[0], min(self.ctrl_range[1], ctrl))
|
||||
|
||||
# Additive driver bias (e.g. H-bridge asymmetry).
|
||||
ctrl += self.action_bias
|
||||
|
||||
# Deadzone
|
||||
dz_pos, dz_neg = self.deadzone
|
||||
if ctrl >= 0 and ctrl < dz_pos:
|
||||
return 0.0
|
||||
if ctrl < 0 and ctrl > -dz_neg:
|
||||
return 0.0
|
||||
|
||||
# Gear compensation: rescale so ctrl × gear_avg ≈ action × gear_dir
|
||||
gear_avg = self.gear_avg
|
||||
if gear_avg > 1e-8:
|
||||
gear_dir = self.gear[0] if ctrl >= 0 else self.gear[1]
|
||||
ctrl *= gear_dir / gear_avg
|
||||
|
||||
return ctrl
|
||||
|
||||
def compute_motor_force(self, vel: float, ctrl: float,
|
||||
friction_scale: float = 1.0,
|
||||
damping_scale: float = 1.0) -> float:
|
||||
"""Asymmetric friction (Coulomb + Stribeck), damping, drag, back-EMF.
|
||||
|
||||
``friction_scale`` / ``damping_scale`` multiply the friction and
|
||||
viscous-damping terms for per-env domain randomization
|
||||
(1.0 = no randomization, the default used by sysid).
|
||||
"""
|
||||
torque = 0.0
|
||||
|
||||
# Coulomb + Stribeck friction (direction-dependent). The Stribeck
|
||||
# boost adds extra friction at low speed that decays as exp(-(v/vs)²)
|
||||
# — crucial for cheap brushed motors near standstill.
|
||||
fl_pos, fl_neg = self.frictionloss
|
||||
if abs(vel) > 1e-6:
|
||||
fl = fl_pos if vel > 0 else fl_neg
|
||||
if self.stribeck_friction_boost > 0:
|
||||
fl += self.stribeck_friction_boost * math.exp(
|
||||
-((abs(vel) / self.stribeck_vel) ** 2)
|
||||
)
|
||||
torque -= math.copysign(fl * friction_scale, vel)
|
||||
|
||||
# Viscous damping (direction-dependent)
|
||||
damp = (self.damping[0] if vel > 0 else self.damping[1]) * damping_scale
|
||||
torque -= damp * vel
|
||||
|
||||
# Quadratic velocity drag
|
||||
if self.viscous_quadratic > 0:
|
||||
torque -= self.viscous_quadratic * vel * abs(vel)
|
||||
|
||||
# Back-EMF torque reduction
|
||||
if self.back_emf_gain > 0 and abs(ctrl) > 1e-6:
|
||||
torque -= self.back_emf_gain * vel * math.copysign(1.0, ctrl)
|
||||
|
||||
return max(-10.0, min(10.0, torque))
|
||||
|
||||
def transform_action(self, action):
|
||||
"""Vectorised clip + bias + deadzone + gear compensation (torch batch).
|
||||
|
||||
Must produce the same result as ``transform_ctrl`` element-wise.
|
||||
"""
|
||||
action = action.clamp(self.ctrl_range[0], self.ctrl_range[1])
|
||||
action = action + self.action_bias
|
||||
|
||||
dz_pos, dz_neg = self.deadzone
|
||||
if dz_pos > 0 or dz_neg > 0:
|
||||
pos_dead = (action >= 0) & (action < dz_pos)
|
||||
neg_dead = (action < 0) & (action > -dz_neg)
|
||||
action = action.masked_fill(pos_dead | neg_dead, 0.0)
|
||||
|
||||
gear_avg = self.gear_avg
|
||||
if gear_avg > 1e-8 and self.gear[0] != self.gear[1]:
|
||||
pos = action >= 0
|
||||
action = torch.where(
|
||||
pos, action * (self.gear[0] / gear_avg),
|
||||
action * (self.gear[1] / gear_avg),
|
||||
)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JointConfig:
|
||||
"""Per-joint overrides applied on top of the URDF values."""
|
||||
damping: float | None = None
|
||||
armature: float | None = None # reflected rotor inertia (kg·m²)
|
||||
frictionloss: float | None = None # Coulomb/dry friction torque (N·m)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RobotConfig:
|
||||
"""Complete robot hardware description."""
|
||||
urdf_path: Path = dataclasses.field(default_factory=lambda: Path())
|
||||
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
|
||||
joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
def load_robot_config(robot_dir: str | Path) -> RobotConfig:
|
||||
"""Load robot.yaml from a directory and resolve the URDF path.
|
||||
|
||||
Expected layout:
|
||||
robot_dir/
|
||||
robot.yaml ← hardware config
|
||||
some_robot.urdf ← CAD export
|
||||
meshes/ ← optional mesh files
|
||||
"""
|
||||
robot_dir = Path(robot_dir).resolve()
|
||||
yaml_path = robot_dir / "robot.yaml"
|
||||
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"Robot config not found: {yaml_path}")
|
||||
|
||||
raw = yaml.safe_load(yaml_path.read_text())
|
||||
|
||||
# Resolve URDF path relative to robot.yaml directory
|
||||
urdf_filename = raw.get("urdf", "")
|
||||
if not urdf_filename:
|
||||
raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}")
|
||||
urdf_path = robot_dir / urdf_filename
|
||||
if not urdf_path.exists():
|
||||
raise FileNotFoundError(f"URDF not found: {urdf_path}")
|
||||
|
||||
# Parse actuators — ignore unknown keys (newer sysid exports may add
|
||||
# fields before the loader learns about them) instead of crashing.
|
||||
known_fields = {f.name for f in dataclasses.fields(ActuatorConfig)}
|
||||
actuators = []
|
||||
for a in raw.get("actuators", []):
|
||||
unknown = set(a) - known_fields
|
||||
if unknown:
|
||||
log.warning(
|
||||
"robot_yaml_unknown_actuator_keys",
|
||||
keys=sorted(unknown), file=str(yaml_path),
|
||||
)
|
||||
a = {k: v for k, v in a.items() if k in known_fields}
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
for key in ("gear", "deadzone", "damping", "frictionloss"):
|
||||
if key in a:
|
||||
a[key] = _as_pair(a[key])
|
||||
actuators.append(ActuatorConfig(**a))
|
||||
|
||||
# Parse joint overrides
|
||||
joints = {}
|
||||
for name, jcfg in raw.get("joints", {}).items():
|
||||
joints[name] = JointConfig(**jcfg)
|
||||
|
||||
config = RobotConfig(
|
||||
urdf_path=urdf_path,
|
||||
actuators=actuators,
|
||||
joints=joints,
|
||||
)
|
||||
|
||||
log.debug("robot_config_loaded", robot_dir=str(robot_dir),
|
||||
urdf=urdf_filename, num_actuators=len(actuators),
|
||||
joint_overrides=list(joints.keys()))
|
||||
|
||||
return config
|
||||
@@ -1,9 +1,12 @@
|
||||
import dataclasses
|
||||
import abc
|
||||
from typing import Any, Generic, TypeVar
|
||||
from src.core.env import BaseEnv
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -11,12 +14,31 @@ T = TypeVar("T")
|
||||
class BaseRunnerConfig:
|
||||
num_envs: int = 1
|
||||
device: str = "cpu"
|
||||
history_length: int = 0 # 0 = plain obs, >0 = append (obs, action) history
|
||||
|
||||
# ── Domain randomization (sim-to-real) ─────────────────────────
|
||||
# Empty dict = disabled (every field below is a no-op). Supported keys:
|
||||
# qpos_noise_std: float — Gaussian sensor noise on joint angles (rad)
|
||||
# qvel_noise_std: float — Gaussian sensor noise on joint velocities (rad/s)
|
||||
# action_delay_steps: [lo, hi] — per-env integer control-step latency
|
||||
# friction_scale: [lo, hi] — per-env multiplier on Coulomb friction
|
||||
# damping_scale: [lo, hi] — per-env multiplier on viscous damping
|
||||
# torque_scale: [lo, hi] — per-env multiplier on applied motor torque
|
||||
# With history_length > 0 the policy can implicitly infer the sampled
|
||||
# dynamics from the recent (obs, action) window — end-to-end adaptation.
|
||||
domain_rand: dict = dataclasses.field(default_factory=dict)
|
||||
|
||||
class BaseRunner(abc.ABC, Generic[T]):
|
||||
def __init__(self, env: BaseEnv, config: T) -> None:
|
||||
self.env = env
|
||||
self.config = config
|
||||
|
||||
# Resolve "auto" device before anything uses it
|
||||
if getattr(self.config, "device", None) == "auto":
|
||||
self.config.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self._last_actions: torch.Tensor | None = None
|
||||
|
||||
self._sim_initialize(config)
|
||||
|
||||
self.observation_space = self.env.observation_space
|
||||
@@ -27,6 +49,28 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
self.config.num_envs, dtype=torch.long, device=self.config.device
|
||||
)
|
||||
|
||||
# ── Domain randomization (latency / sensor noise / dynamics) ─
|
||||
self._setup_domain_rand()
|
||||
|
||||
# ── History buffer (implicit adaptation input) ────────────
|
||||
self._history_len: int = getattr(self.config, "history_length", 0)
|
||||
if self._history_len > 0:
|
||||
obs_dim = self.observation_space.shape[0]
|
||||
act_dim = self.action_space.shape[0]
|
||||
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
|
||||
# Ring buffer: (num_envs, history_length, obs_dim + act_dim)
|
||||
self._history_buf = torch.zeros(
|
||||
self.config.num_envs, self._history_len, self._history_step_dim,
|
||||
device=self.config.device,
|
||||
)
|
||||
|
||||
# Policy obs = [raw_obs, history_flat]
|
||||
from gymnasium import spaces
|
||||
aug_dim = obs_dim + self._history_len * self._history_step_dim
|
||||
self.observation_space = spaces.Box(
|
||||
low=-torch.inf, high=torch.inf, shape=(aug_dim,),
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def num_envs(self) -> int:
|
||||
@@ -47,51 +91,243 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
|
||||
@abc.abstractmethod
|
||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reset the given envs; return FULL-batch (num_envs, nq/nv) state.
|
||||
|
||||
Returning the full batch (not just the reset envs) lets GPU
|
||||
backends hand back zero-copy views without host synchronisation —
|
||||
the caller indexes the reset rows itself.
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def _sim_close(self) -> None:
|
||||
...
|
||||
"""Release simulator resources. Override for extra cleanup."""
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
|
||||
# ── Domain randomization ─────────────────────────────────────
|
||||
|
||||
_SCALE_FIELDS = ("friction_scale", "damping_scale", "torque_scale")
|
||||
|
||||
def _setup_domain_rand(self) -> None:
|
||||
"""Parse the domain_rand config into per-env buffers.
|
||||
|
||||
All buffers are no-ops when ``domain_rand`` is empty: scales are 1.0,
|
||||
delay is 0 and noise std is 0.
|
||||
"""
|
||||
dr = dict(getattr(self.config, "domain_rand", {}) or {})
|
||||
n = self.config.num_envs
|
||||
dev = self.config.device
|
||||
|
||||
# Fixed (not per-env) Gaussian sensor noise.
|
||||
self._qpos_noise_std = float(dr.get("qpos_noise_std", 0.0))
|
||||
self._qvel_noise_std = float(dr.get("qvel_noise_std", 0.0))
|
||||
|
||||
# Per-env multiplicative dynamics scales (applied by the sim runner).
|
||||
self._dr_scales: dict[str, torch.Tensor] = {
|
||||
f: torch.ones(n, device=dev) for f in self._SCALE_FIELDS
|
||||
}
|
||||
self._dr_scale_ranges: dict[str, tuple[float, float]] = {}
|
||||
for f in self._SCALE_FIELDS:
|
||||
rng = dr.get(f)
|
||||
if rng:
|
||||
self._dr_scale_ranges[f] = (float(rng[0]), float(rng[1]))
|
||||
|
||||
# Per-env integer action delay (in control steps).
|
||||
self._dr_delay = torch.zeros(n, dtype=torch.long, device=dev)
|
||||
delay_range = dr.get("action_delay_steps")
|
||||
if delay_range:
|
||||
self._delay_range = (int(delay_range[0]), int(delay_range[1]))
|
||||
self._max_delay = int(delay_range[1])
|
||||
else:
|
||||
self._delay_range = (0, 0)
|
||||
self._max_delay = 0
|
||||
|
||||
# Action-delay ring buffer: (num_envs, max_delay + 1, act_dim).
|
||||
if self._max_delay > 0:
|
||||
act_dim = self.env.action_space.shape[0]
|
||||
self._action_buf = torch.zeros(
|
||||
n, self._max_delay + 1, act_dim, device=dev,
|
||||
)
|
||||
|
||||
def _resample_domain_rand(self, env_ids: torch.Tensor) -> None:
|
||||
"""Sample fresh per-env DR factors (call on every (re)set)."""
|
||||
if env_ids.numel() == 0:
|
||||
return
|
||||
dev = self.config.device
|
||||
for name, (lo, hi) in self._dr_scale_ranges.items():
|
||||
vals = torch.rand(env_ids.numel(), device=dev) * (hi - lo) + lo
|
||||
self._dr_scales[name][env_ids] = vals
|
||||
if self._max_delay > 0:
|
||||
self._dr_delay[env_ids] = torch.randint(
|
||||
self._delay_range[0], self._delay_range[1] + 1,
|
||||
(env_ids.numel(),), device=dev,
|
||||
)
|
||||
|
||||
def _reset_action_buffer(self, env_ids: torch.Tensor) -> None:
|
||||
if self._max_delay > 0:
|
||||
self._action_buf[env_ids] = 0.0
|
||||
|
||||
def _apply_action_delay(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Return the per-env delayed action that the simulator should apply.
|
||||
|
||||
The policy's commanded action is what gets stored in history; only
|
||||
the action handed to ``_sim_step`` is delayed.
|
||||
"""
|
||||
if self._max_delay <= 0:
|
||||
return actions
|
||||
self._action_buf = torch.roll(self._action_buf, 1, dims=1)
|
||||
self._action_buf[:, 0] = actions
|
||||
idx = torch.arange(self.num_envs, device=self.device)
|
||||
return self._action_buf[idx, self._dr_delay]
|
||||
|
||||
def _add_sensor_noise(
|
||||
self, qpos: torch.Tensor, qvel: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self._qpos_noise_std > 0:
|
||||
qpos = qpos + torch.randn_like(qpos) * self._qpos_noise_std
|
||||
if self._qvel_noise_std > 0:
|
||||
qvel = qvel + torch.randn_like(qvel) * self._qvel_noise_std
|
||||
return qpos, qvel
|
||||
|
||||
def _compute_obs(self, qpos: torch.Tensor, qvel: torch.Tensor) -> torch.Tensor:
|
||||
"""Observation the policy sees — built from noisy (sensor) state."""
|
||||
nqpos, nqvel = self._add_sensor_noise(qpos, qvel)
|
||||
return self.env.compute_observations(self.env.build_state(nqpos, nqvel))
|
||||
|
||||
# ── Observation augmentation ─────────────────────────────────
|
||||
|
||||
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
"""Append the flattened (obs, action) history when enabled."""
|
||||
if self._history_len <= 0:
|
||||
return obs
|
||||
hist_flat = self._history_buf.reshape(obs.shape[0], -1)
|
||||
return torch.cat([obs, hist_flat], dim=-1)
|
||||
|
||||
def _push_history(self, obs: torch.Tensor, actions: torch.Tensor,
|
||||
env_ids: torch.Tensor | None = None) -> None:
|
||||
"""Push (obs, action) into the ring buffer (shift left, append right)."""
|
||||
if self._history_len <= 0:
|
||||
return
|
||||
step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1)
|
||||
if env_ids is None:
|
||||
# All envs.
|
||||
self._history_buf = torch.roll(self._history_buf, -1, dims=1)
|
||||
self._history_buf[:, -1] = step
|
||||
else:
|
||||
self._history_buf[env_ids] = torch.roll(
|
||||
self._history_buf[env_ids], -1, dims=1
|
||||
)
|
||||
self._history_buf[env_ids, -1] = step[env_ids]
|
||||
|
||||
def _reset_history(self, env_ids: torch.Tensor) -> None:
|
||||
"""Zero the history buffer for reset envs."""
|
||||
if self._history_len > 0:
|
||||
self._history_buf[env_ids] = 0.0
|
||||
|
||||
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
|
||||
all_ids = torch.arange(self.num_envs, device=self.device)
|
||||
self._resample_domain_rand(all_ids)
|
||||
self._reset_action_buffer(all_ids)
|
||||
qpos, qvel = self._sim_reset(all_ids)
|
||||
self.step_counts.zero_()
|
||||
self._reset_history(all_ids)
|
||||
|
||||
obs = self._compute_obs(qpos, qvel)
|
||||
return self._augment_obs(obs), {}
|
||||
|
||||
state = self.env.build_state(qpos, qvel)
|
||||
obs = self.env.compute_observations(state)
|
||||
return obs, {}
|
||||
|
||||
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
|
||||
qpos, qvel = self._sim_step(actions)
|
||||
prev_actions = (
|
||||
self._last_actions
|
||||
if self._last_actions is not None
|
||||
else torch.zeros_like(actions)
|
||||
)
|
||||
self._last_actions = actions
|
||||
# Latency: the simulator applies a (per-env) delayed action.
|
||||
sim_actions = self._apply_action_delay(actions)
|
||||
qpos, qvel = self._sim_step(sim_actions)
|
||||
self.step_counts += 1
|
||||
|
||||
state = self.env.build_state(qpos, qvel)
|
||||
obs = self.env.compute_observations(state)
|
||||
rewards = self.env.compute_rewards(state, actions)
|
||||
terminated = self.env.compute_terminations(state)
|
||||
# Reward / termination use the TRUE state (no sensor noise) so the
|
||||
# learning signal and safety checks stay clean.
|
||||
clean_state = self.env.build_state(qpos, qvel)
|
||||
rewards = self.env.compute_rewards(clean_state, actions, prev_actions)
|
||||
terminated = self.env.compute_terminations(clean_state)
|
||||
truncated = self.env.compute_truncations(self.step_counts)
|
||||
|
||||
# The observation the policy sees is built from the NOISY sensor state.
|
||||
obs = self._compute_obs(qpos, qvel)
|
||||
|
||||
# Push current (obs, action) into history before augmenting.
|
||||
self._push_history(obs, actions)
|
||||
|
||||
info: dict[str, Any] = {}
|
||||
|
||||
done = terminated | truncated
|
||||
done_ids = done.nonzero(as_tuple=False).squeeze(-1)
|
||||
|
||||
if done_ids.numel() > 0:
|
||||
info["final_observations"] = obs[done_ids].clone()
|
||||
info["final_observations"] = self._augment_obs(obs)[done_ids].clone()
|
||||
info["final_env_ids"] = done_ids.clone()
|
||||
|
||||
reset_qpos, reset_qvel = self._sim_reset(done_ids)
|
||||
# New episode → fresh dynamics + cleared latency buffer.
|
||||
self._resample_domain_rand(done_ids)
|
||||
self._reset_action_buffer(done_ids)
|
||||
full_qpos, full_qvel = self._sim_reset(done_ids)
|
||||
self.step_counts[done_ids] = 0
|
||||
self._reset_history(done_ids)
|
||||
|
||||
reset_state = self.env.build_state(reset_qpos, reset_qvel)
|
||||
obs[done_ids] = self.env.compute_observations(reset_state)
|
||||
# _sim_reset returns the full batch — index the reset rows here.
|
||||
obs[done_ids] = self._compute_obs(
|
||||
full_qpos[done_ids], full_qvel[done_ids],
|
||||
)
|
||||
|
||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
raise NotImplementedError("Render method not implemented for this runner.")
|
||||
|
||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Return a raw RGB frame. Override in subclass."""
|
||||
raise NotImplementedError("Render not implemented for this runner.")
|
||||
|
||||
def render(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Render frame with action overlay."""
|
||||
frame = self._render_frame(env_idx)
|
||||
if self._last_actions is not None:
|
||||
ctrl = float(self._last_actions[env_idx, 0].clamp(-1.0, 1.0))
|
||||
_draw_action_overlay(frame, ctrl)
|
||||
return frame
|
||||
|
||||
def close(self) -> None:
|
||||
self._sim_close()
|
||||
self._sim_close()
|
||||
|
||||
|
||||
def _draw_action_overlay(frame: np.ndarray, action: float) -> None:
|
||||
"""Draw an action bar on a rendered frame (no OpenCV needed).
|
||||
|
||||
Bar is centered horizontally: green to the right (+), red to the left (-).
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
bar_y = h - 30
|
||||
bar_h = 16
|
||||
bar_x_center = w // 2
|
||||
bar_half_w = w // 4
|
||||
bar_x_left = bar_x_center - bar_half_w
|
||||
bar_x_right = bar_x_center + bar_half_w
|
||||
|
||||
# Background (dark grey)
|
||||
frame[bar_y:bar_y + bar_h, bar_x_left:bar_x_right] = [40, 40, 40]
|
||||
|
||||
# Filled bar
|
||||
fill_len = int(abs(action) * bar_half_w)
|
||||
if action > 0:
|
||||
color = [60, 200, 60] # green
|
||||
x0 = bar_x_center
|
||||
x1 = min(bar_x_center + fill_len, bar_x_right)
|
||||
else:
|
||||
color = [200, 60, 60] # red
|
||||
x1 = bar_x_center
|
||||
x0 = max(bar_x_center - fill_len, bar_x_left)
|
||||
frame[bar_y:bar_y + bar_h, x0:x1] = color
|
||||
|
||||
# Center tick mark (white)
|
||||
frame[bar_y:bar_y + bar_h, bar_x_center - 1:bar_x_center + 1] = [255, 255, 255]
|
||||
@@ -1,53 +0,0 @@
|
||||
import dataclasses
|
||||
import torch
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from gymnasium import spaces
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CartPoleState:
|
||||
cart_pos: torch.float # (num_envs,)
|
||||
cart_vel: torch.float # (num_envs,)
|
||||
pole_angle: torch.float # (num_envs,)
|
||||
pole_vel: torch.float # (num_envs,)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CartPoleConfig(BaseEnvConfig):
|
||||
"""CartPole task config. All values come from Hydra YAML."""
|
||||
angle_threshold: float = 0.418 # ~24 degrees
|
||||
cart_limit: float = 2.4
|
||||
reward_alive: float = 1.0
|
||||
reward_pole_upright_scale: float = 1.0
|
||||
reward_action_penalty_scale: float = 0.01
|
||||
|
||||
class CartPoleEnv(BaseEnv[CartPoleConfig]):
|
||||
def __init__(self, config: CartPoleConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@property
|
||||
def observation_space(self) -> torch.Tensor:
|
||||
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))
|
||||
|
||||
@property
|
||||
def action_space(self) -> torch.Tensor:
|
||||
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
|
||||
|
||||
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
|
||||
return CartPoleState(
|
||||
cart_pos=qpos[:, 0],
|
||||
cart_vel=qvel[:, 0],
|
||||
pole_angle=qpos[:, 1],
|
||||
pole_vel=qvel[:, 1],
|
||||
)
|
||||
|
||||
def compute_observations(self, state: CartPoleState) -> torch.Tensor:
|
||||
return torch.stack([state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1)
|
||||
|
||||
def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
||||
upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
|
||||
action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
|
||||
return self.config.reward_alive + upright - action_penalty
|
||||
|
||||
def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
|
||||
pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
|
||||
cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
|
||||
return pole_fallen | cart_out_of_bounds
|
||||
181
src/envs/rotary_cartpole.py
Normal file
181
src/envs/rotary_cartpole.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import dataclasses
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RotaryCartPoleState:
|
||||
motor_angle: torch.Tensor # (num_envs,)
|
||||
motor_vel: torch.Tensor # (num_envs,)
|
||||
pendulum_angle: torch.Tensor # (num_envs,)
|
||||
pendulum_vel: torch.Tensor # (num_envs,)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RotaryCartPoleConfig(BaseEnvConfig):
|
||||
"""Rotary inverted pendulum (Furuta pendulum) task config.
|
||||
|
||||
The motor rotates the arm horizontally; the pendulum swings freely
|
||||
at the arm tip. Goal: swing the pendulum up and balance it upright.
|
||||
"""
|
||||
# Reward shaping
|
||||
reward_upright_scale: float = 1.0 # upright reward ∈ [0, scale]
|
||||
alive_bonus: float = 0.25 # per-step survival bonus (must stay alive > die)
|
||||
balance_bonus: float = 2.0 # extra reward for upright AND still (beats spinning)
|
||||
balance_vel_scale: float = 0.5 # decay rate of the bonus with pendulum speed
|
||||
motor_vel_penalty: float = 0.01 # penalise high motor angular velocity
|
||||
motor_angle_penalty: float = 0.05 # penalise deviation from centre
|
||||
action_penalty: float = 0.05 # penalise large actions (energy cost)
|
||||
action_rate_penalty: float = 0.01 # penalise action changes (smoothness —
|
||||
# critical with ~100 ms real motor lag)
|
||||
|
||||
# ── Initial state randomisation ──
|
||||
pendulum_init_range_deg: float = 180.0 # pendulum starts in [-range, +range]
|
||||
|
||||
# ── Software safety limit (env-level, on top of URDF + hardware) ──
|
||||
motor_angle_limit_deg: float = 90.0 # terminate episode if exceeded
|
||||
|
||||
|
||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
"""Furuta pendulum / rotary inverted pendulum environment.
|
||||
|
||||
Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum
|
||||
|
||||
Observations (6):
|
||||
[sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
|
||||
Using sin/cos avoids discontinuities at ±π for continuous joints.
|
||||
|
||||
Actions (1):
|
||||
Torque applied to the motor_joint (normalised to [-1, 1]).
|
||||
"""
|
||||
|
||||
def __init__(self, config: RotaryCartPoleConfig):
|
||||
super().__init__(config)
|
||||
|
||||
# ── Spaces ───────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
|
||||
|
||||
# ── State building ───────────────────────────────────────────
|
||||
|
||||
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
|
||||
return RotaryCartPoleState(
|
||||
motor_angle=qpos[:, 0],
|
||||
motor_vel=qvel[:, 0],
|
||||
pendulum_angle=qpos[:, 1],
|
||||
pendulum_vel=qvel[:, 1],
|
||||
)
|
||||
|
||||
# ── Observations ─────────────────────────────────────────────
|
||||
|
||||
def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||
obs = [
|
||||
torch.sin(state.motor_angle),
|
||||
torch.cos(state.motor_angle),
|
||||
torch.sin(state.pendulum_angle),
|
||||
torch.cos(state.pendulum_angle),
|
||||
state.motor_vel,
|
||||
state.pendulum_vel,
|
||||
]
|
||||
return torch.stack(obs, dim=-1)
|
||||
|
||||
# ── Rewards ──────────────────────────────────────────────────
|
||||
|
||||
def compute_rewards(
|
||||
self,
|
||||
state: RotaryCartPoleState,
|
||||
actions: torch.Tensor,
|
||||
prev_actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Upright shaping ∈ [0, 1]: 0 hanging down (θ=0), 1 fully upright (θ=π).
|
||||
# Non-negative by design so *surviving* always beats ending the episode early
|
||||
# (otherwise the optimum is to slam the arm into the ±limit — "suicide policy").
|
||||
upright = 0.5 * (1.0 - torch.cos(state.pendulum_angle))
|
||||
|
||||
# Balanced bonus: large ONLY when near the top AND nearly still. A freely
|
||||
# spinning pendulum passes through the top at high speed, so stillness≈0 and
|
||||
# it earns ~none of this — making true balancing strictly dominate the
|
||||
# "just keep spinning in full loops" local optimum.
|
||||
stillness = torch.exp(-self.config.balance_vel_scale * state.pendulum_vel.pow(2))
|
||||
balance = self.config.balance_bonus * upright * stillness
|
||||
|
||||
# Per-step alive bonus keeps a not-yet-upright step net-positive despite
|
||||
# penalties, so the −10 termination penalty is genuinely a deterrent.
|
||||
reward = (upright * self.config.reward_upright_scale
|
||||
+ balance
|
||||
+ self.config.alive_bonus)
|
||||
|
||||
# Penalise fast motor spinning (discourages violent oscillation)
|
||||
reward = reward - self.config.motor_vel_penalty * state.motor_vel.pow(2)
|
||||
|
||||
# Penalise motor deviation from centre (keep arm near zero)
|
||||
reward = reward - self.config.motor_angle_penalty * state.motor_angle.pow(2)
|
||||
|
||||
# Penalise large actions (energy efficiency / smoother control)
|
||||
reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2)
|
||||
|
||||
# Penalise rapid action changes — a jittery policy is unrealisable
|
||||
# through the real motor's ~100 ms lag and excites unmodeled dynamics.
|
||||
action_rate = (actions - prev_actions).squeeze(-1).pow(2)
|
||||
reward = reward - self.config.action_rate_penalty * action_rate
|
||||
|
||||
# Penalty for exceeding motor angle limit (episode also terminates)
|
||||
limit_rad = math.radians(self.config.motor_angle_limit_deg)
|
||||
exceeded = state.motor_angle.abs() >= limit_rad
|
||||
reward = torch.where(exceeded, torch.tensor(-10.0, device=reward.device), reward)
|
||||
|
||||
return reward
|
||||
|
||||
# ── Initial state randomization ──────────────────────────────
|
||||
|
||||
def initial_state_ranges(
|
||||
self, nq: int, nv: int,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Small motor perturbation; wide pendulum angle (swing-up task)."""
|
||||
qpos_lo = np.full(nq, -0.05)
|
||||
qpos_hi = np.full(nq, 0.05)
|
||||
qvel_lo = np.full(nv, -0.05)
|
||||
qvel_hi = np.full(nv, 0.05)
|
||||
pend_range = math.radians(self.config.pendulum_init_range_deg)
|
||||
if pend_range > 0 and nq >= 2:
|
||||
qpos_lo[1] = -pend_range
|
||||
qpos_hi[1] = pend_range
|
||||
return qpos_lo, qpos_hi, qvel_lo, qvel_hi
|
||||
|
||||
# ── Terminations ─────────────────────────────────────────────
|
||||
|
||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||
# Software safety: terminate if motor angle exceeds limit.
|
||||
limit_rad = math.radians(self.config.motor_angle_limit_deg)
|
||||
exceeded = state.motor_angle.abs() >= limit_rad
|
||||
return exceeded
|
||||
|
||||
# ── Reset readiness (for SerialRunner) ───────────────────────
|
||||
|
||||
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
|
||||
"""Pendulum must be hanging still and motor near center."""
|
||||
motor_angle = float(qpos[0, 0])
|
||||
pend_angle = float(qpos[0, 1])
|
||||
motor_vel = float(qvel[0, 0])
|
||||
pend_vel = float(qvel[0, 1])
|
||||
|
||||
# Pendulum near hanging-down (angle ~0) and slow
|
||||
angle_ok = abs(pend_angle) < math.radians(2.0)
|
||||
vel_ok = abs(pend_vel) < math.radians(5.0)
|
||||
# Motor near center and slow
|
||||
motor_ok = abs(motor_angle) < math.radians(5.0)
|
||||
motor_vel_ok = abs(motor_vel) < math.radians(10.0)
|
||||
return angle_ok and vel_ok and motor_ok and motor_vel_ok
|
||||
|
||||
|
||||
1
src/hpo/__init__.py
Normal file
1
src/hpo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Hyperparameter optimization — SMAC3 + ClearML Successive Halving."""
|
||||
684
src/hpo/smac3.py
Normal file
684
src/hpo/smac3.py
Normal file
@@ -0,0 +1,684 @@
|
||||
# Requires: pip install smac==2.0.0 ConfigSpace==0.4.20
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from clearml import Task
|
||||
from clearml.automation.optimization import Objective, SearchStrategy
|
||||
from clearml.automation.parameters import Parameter
|
||||
from clearml.backend_interface.session import SendError
|
||||
from ConfigSpace import (
|
||||
CategoricalHyperparameter,
|
||||
ConfigurationSpace,
|
||||
UniformFloatHyperparameter,
|
||||
UniformIntegerHyperparameter,
|
||||
)
|
||||
from smac import MultiFidelityFacade
|
||||
from smac.intensifier.successive_halving import SuccessiveHalving
|
||||
from smac.runhistory.dataclasses import TrialInfo, TrialValue
|
||||
from smac.scenario import Scenario
|
||||
|
||||
|
||||
def retry_on_error(max_retries=5, initial_delay=2.0, backoff=2.0, exceptions=(Exception,)):
|
||||
"""Decorator to retry a function on exception with exponential backoff."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
delay = initial_delay
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except exceptions:
|
||||
if attempt == max_retries - 1:
|
||||
return None # Return None instead of raising
|
||||
time.sleep(delay)
|
||||
delay *= backoff
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _encode_param_name(name: str) -> str:
|
||||
"""Encode parameter name for ConfigSpace (replace / with __SLASH__)"""
|
||||
return name.replace("/", "__SLASH__")
|
||||
|
||||
|
||||
def _decode_param_name(name: str) -> str:
|
||||
"""Decode parameter name back to original (replace __SLASH__ with /)"""
|
||||
return name.replace("__SLASH__", "/")
|
||||
|
||||
|
||||
def _convert_param_to_cs(param: Parameter):
|
||||
"""
|
||||
Convert a ClearML Parameter into a ConfigSpace hyperparameter,
|
||||
adapted to ConfigSpace>=1.x (no more 'q' argument).
|
||||
"""
|
||||
# Encode the name to avoid ConfigSpace issues with special chars like '/'
|
||||
name = _encode_param_name(param.name)
|
||||
|
||||
# Categorical / discrete list
|
||||
if hasattr(param, "values"):
|
||||
return CategoricalHyperparameter(name=name, choices=list(param.values))
|
||||
|
||||
# Numeric range (float or int)
|
||||
if hasattr(param, "min_value") and hasattr(param, "max_value"):
|
||||
min_val = param.min_value
|
||||
max_val = param.max_value
|
||||
|
||||
# Check if this should be treated as integer
|
||||
if isinstance(min_val, int) and isinstance(max_val, int):
|
||||
log = getattr(param, "log_scale", False)
|
||||
|
||||
# Check for step_size for quantization
|
||||
if hasattr(param, "step_size"):
|
||||
sv = int(param.step_size)
|
||||
if sv != 1:
|
||||
# emulate quantization by explicit list of values
|
||||
choices = list(range(min_val, max_val + 1, sv))
|
||||
return CategoricalHyperparameter(name=name, choices=choices)
|
||||
|
||||
# Simple uniform integer range
|
||||
return UniformIntegerHyperparameter(name=name, lower=min_val, upper=max_val, log=log)
|
||||
else:
|
||||
# Treat as float
|
||||
lower, upper = float(min_val), float(max_val)
|
||||
log = getattr(param, "log_scale", False)
|
||||
return UniformFloatHyperparameter(name=name, lower=lower, upper=upper, log=log)
|
||||
|
||||
raise ValueError(f"Unsupported Parameter type: {type(param)}")
|
||||
|
||||
|
||||
class OptimizerSMAC(SearchStrategy):
|
||||
"""
|
||||
SMAC3-based hyperparameter optimizer, matching OptimizerBOHB interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_task_id: str,
|
||||
hyper_parameters: Sequence[Parameter],
|
||||
objective_metric: Objective,
|
||||
execution_queue: str,
|
||||
num_concurrent_workers: int,
|
||||
min_iteration_per_job: int,
|
||||
max_iteration_per_job: int,
|
||||
total_max_jobs: int,
|
||||
pool_period_min: float = 2.0,
|
||||
time_limit_per_job: float | None = None,
|
||||
compute_time_limit: float | None = None,
|
||||
**smac_kwargs: Any,
|
||||
):
|
||||
# Initialize base SearchStrategy
|
||||
super().__init__(
|
||||
base_task_id=base_task_id,
|
||||
hyper_parameters=hyper_parameters,
|
||||
objective_metric=objective_metric,
|
||||
execution_queue=execution_queue,
|
||||
num_concurrent_workers=num_concurrent_workers,
|
||||
pool_period_min=pool_period_min,
|
||||
time_limit_per_job=time_limit_per_job,
|
||||
compute_time_limit=compute_time_limit,
|
||||
min_iteration_per_job=min_iteration_per_job,
|
||||
max_iteration_per_job=max_iteration_per_job,
|
||||
total_max_jobs=total_max_jobs,
|
||||
)
|
||||
|
||||
# Expose for internal use (access private attributes from base class)
|
||||
self.execution_queue = self._execution_queue
|
||||
self.min_iterations = min_iteration_per_job
|
||||
self.max_iterations = max_iteration_per_job
|
||||
self.num_concurrent_workers = self._num_concurrent_workers # Fix: access private attribute
|
||||
|
||||
# Objective details
|
||||
# Handle both single objective (string) and multi-objective (list) cases
|
||||
if isinstance(self._objective_metric.title, list):
|
||||
self.metric_title = self._objective_metric.title[0] # Use first objective
|
||||
else:
|
||||
self.metric_title = self._objective_metric.title
|
||||
|
||||
if isinstance(self._objective_metric.series, list):
|
||||
self.metric_series = self._objective_metric.series[0] # Use first series
|
||||
else:
|
||||
self.metric_series = self._objective_metric.series
|
||||
|
||||
# ClearML Objective stores sign as a list, e.g., ['max'] or ['min']
|
||||
objective_sign = getattr(self._objective_metric, "sign", None) or getattr(self._objective_metric, "order", None)
|
||||
|
||||
# Handle list case - extract first element
|
||||
if isinstance(objective_sign, list):
|
||||
objective_sign = objective_sign[0] if objective_sign else "max"
|
||||
|
||||
# Default to max if nothing found
|
||||
if objective_sign is None:
|
||||
objective_sign = "max"
|
||||
|
||||
self.maximize_metric = str(objective_sign).lower() in ("max", "max_global")
|
||||
|
||||
# Build ConfigSpace
|
||||
self.config_space = ConfigurationSpace(seed=42)
|
||||
for p in self._hyper_parameters: # Access private attribute correctly
|
||||
cs_hp = _convert_param_to_cs(p)
|
||||
self.config_space.add(cs_hp)
|
||||
|
||||
# Configure SMAC Scenario
|
||||
scenario = Scenario(
|
||||
configspace=self.config_space,
|
||||
n_trials=self.total_max_jobs,
|
||||
min_budget=float(self.min_iterations),
|
||||
max_budget=float(self.max_iterations),
|
||||
walltime_limit=(self.compute_time_limit * 60) if self.compute_time_limit else None,
|
||||
deterministic=True,
|
||||
)
|
||||
|
||||
# Configurable budget parameter name
|
||||
# Default: Hydra/training.total_timesteps (RL-Framework convention)
|
||||
self.budget_param_name = smac_kwargs.pop(
|
||||
"budget_param_name", "Hydra/training.total_timesteps"
|
||||
)
|
||||
|
||||
# Pop our custom kwargs BEFORE passing smac_kwargs to SuccessiveHalving
|
||||
self.max_consecutive_failures = int(
|
||||
smac_kwargs.pop("max_consecutive_failures", 3)
|
||||
)
|
||||
self._consecutive_failures = 0
|
||||
|
||||
# build the Successive Halving intensifier (NOT Hyperband!)
|
||||
# Hyperband runs multiple brackets with different starting budgets - wasteful
|
||||
# Successive Halving: ALL configs start at min_budget, only best get promoted
|
||||
# eta controls the reduction factor (default 3 means keep top 1/3 each round)
|
||||
# eta can be overridden via smac_kwargs from HyperParameterOptimizer
|
||||
eta = smac_kwargs.pop("eta", 3) # Default to 3 if not specified
|
||||
intensifier = SuccessiveHalving(scenario=scenario, eta=eta, **smac_kwargs)
|
||||
|
||||
# now pass that intensifier instance into the facade
|
||||
self.smac = MultiFidelityFacade(
|
||||
scenario=scenario,
|
||||
target_function=lambda config, budget, seed: 0.0,
|
||||
intensifier=intensifier,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Bookkeeping
|
||||
self.running_tasks = {} # task_id -> trial info
|
||||
self.task_start_times = {} # task_id -> start time (for timeout)
|
||||
self.completed_results = []
|
||||
self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
|
||||
self.time_limit_per_job = time_limit_per_job # Store time limit (minutes)
|
||||
|
||||
# Checkpoint continuation tracking: config_key -> {budget: task_id}
|
||||
# Used to find the previous task's checkpoint when promoting a config
|
||||
self.config_to_tasks = {} # config_key -> {budget: task_id}
|
||||
|
||||
# Manual Successive Halving control
|
||||
self.eta = eta
|
||||
self.current_budget = float(self.min_iterations)
|
||||
self.configs_at_budget = {} # budget -> list of (config, score, trial)
|
||||
self.pending_configs = [] # configs waiting to be evaluated at current_budget - list of (trial, prev_task_id)
|
||||
self.evaluated_at_budget = [] # (config, score, trial, task_id) for current budget
|
||||
self.smac_asked_configs = set() # track which configs SMAC has given us
|
||||
|
||||
# Calculate initial rung size for proper Successive Halving
|
||||
# With eta=3: rung sizes are n, n/3, n/9, ...
|
||||
# Total trials = n * (1 + 1/eta + 1/eta^2 + ...) = n * eta/(eta-1) for infinite series
|
||||
# For finite rungs, calculate exactly
|
||||
num_rungs = 1
|
||||
b = float(self.min_iterations)
|
||||
while b * eta <= self.max_iterations:
|
||||
num_rungs += 1
|
||||
b *= eta
|
||||
|
||||
# Sum of geometric series: 1 + 1/eta + 1/eta^2 + ... (num_rungs terms)
|
||||
series_sum = sum(1.0 / (eta**i) for i in range(num_rungs))
|
||||
self.initial_rung_size = int(self.total_max_jobs / series_sum)
|
||||
self.initial_rung_size = max(self.initial_rung_size, self.num_concurrent_workers) # at least num_workers
|
||||
self.configs_needed_for_rung = self.initial_rung_size # how many configs we still need for current rung
|
||||
self.rung_closed = False # whether we've collected all configs for current rung
|
||||
|
||||
@retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
|
||||
def _get_task_safe(self, task_id: str):
|
||||
"""Safely get a task with retry logic."""
|
||||
return Task.get_task(task_id=task_id)
|
||||
|
||||
@retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
|
||||
def _launch_task(self, config: dict, budget: float, prev_task_id: str | None = None):
|
||||
"""Launch a task with retry logic for robustness.
|
||||
|
||||
Args:
|
||||
config: Hyperparameter configuration dict
|
||||
budget: Number of epochs to train
|
||||
prev_task_id: Optional task ID from previous budget to continue from (checkpoint)
|
||||
"""
|
||||
base = self._get_task_safe(task_id=self._base_task_id)
|
||||
if base is None:
|
||||
return None
|
||||
|
||||
clone = Task.clone(
|
||||
source_task=base,
|
||||
name=f"HPO Trial - {base.name}",
|
||||
parent=Task.current_task().id, # Set the current HPO task as parent
|
||||
)
|
||||
# Override hyperparameters
|
||||
for k, v in config.items():
|
||||
# Decode parameter name back to original (with slashes)
|
||||
original_name = _decode_param_name(k)
|
||||
# Convert numpy types to Python built-in types
|
||||
if hasattr(v, "item"): # numpy scalar
|
||||
param_value = v.item()
|
||||
elif isinstance(v, int | float | str | bool):
|
||||
param_value = type(v)(v) # Ensure it's the built-in type
|
||||
else:
|
||||
param_value = v
|
||||
clone.set_parameter(original_name, param_value)
|
||||
# Override budget parameter (e.g. total_timesteps) for multi-fidelity
|
||||
if self.max_iterations != self.min_iterations:
|
||||
clone.set_parameter(self.budget_param_name, int(budget))
|
||||
else:
|
||||
clone.set_parameter(self.budget_param_name, int(self.max_iterations))
|
||||
|
||||
# If we have a previous task, pass its ID so the worker can download the checkpoint
|
||||
if prev_task_id:
|
||||
clone.set_parameter("Hydra/training.resume_from_task_id", prev_task_id)
|
||||
|
||||
Task.enqueue(task=clone, queue_name=self.execution_queue)
|
||||
# Track start time for timeout enforcement
|
||||
self.task_start_times[clone.id] = time.time()
|
||||
return clone
|
||||
|
||||
def start(self):
|
||||
controller = Task.current_task()
|
||||
total_launched = 0
|
||||
|
||||
# Keep launching & collecting until budget exhausted
|
||||
while total_launched < self.total_max_jobs:
|
||||
# Check if current budget rung is complete BEFORE asking for new trials
|
||||
# (no running tasks, no pending configs, and we have results for this budget)
|
||||
if not self.running_tasks and not self.pending_configs and self.evaluated_at_budget:
|
||||
# Rung complete! Promote top performers to next budget
|
||||
|
||||
# Store results for this budget
|
||||
self.configs_at_budget[self.current_budget] = self.evaluated_at_budget.copy()
|
||||
|
||||
# Sort by score (best first)
|
||||
sorted_configs = sorted(
|
||||
self.evaluated_at_budget,
|
||||
key=lambda x: x[1], # score
|
||||
reverse=self.maximize_metric,
|
||||
)
|
||||
|
||||
# Print rung results
|
||||
for _i, (_cfg, _score, _tri, _task_id) in enumerate(sorted_configs[:5], 1):
|
||||
pass
|
||||
|
||||
# Move to next budget?
|
||||
next_budget = self.current_budget * self.eta
|
||||
if next_budget <= self.max_iterations:
|
||||
# How many to promote (top 1/eta)
|
||||
n_promote = max(1, len(sorted_configs) // self.eta)
|
||||
promoted = sorted_configs[:n_promote]
|
||||
|
||||
# Update budget and reset for next rung
|
||||
self.current_budget = next_budget
|
||||
self.evaluated_at_budget = []
|
||||
self.configs_needed_for_rung = 0 # promoted configs are all we need
|
||||
self.rung_closed = True # rung is pre-filled with promoted configs
|
||||
|
||||
# Re-queue promoted configs with new budget
|
||||
# Include the previous task ID for checkpoint continuation
|
||||
for _cfg, _score, old_trial, prev_task_id in promoted:
|
||||
new_trial = TrialInfo(
|
||||
config=old_trial.config,
|
||||
instance=old_trial.instance,
|
||||
seed=old_trial.seed,
|
||||
budget=self.current_budget,
|
||||
)
|
||||
# Store as tuple: (trial, prev_task_id)
|
||||
self.pending_configs.append((new_trial, prev_task_id))
|
||||
else:
|
||||
# All budgets complete
|
||||
break
|
||||
|
||||
# Fill pending_configs with new trials ONLY if we haven't closed this rung yet
|
||||
# For the first rung: ask SMAC for initial_rung_size configs total
|
||||
# For subsequent rungs: only use promoted configs (rung is already closed)
|
||||
while (
|
||||
not self.rung_closed
|
||||
and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
|
||||
< self.initial_rung_size
|
||||
and total_launched < self.total_max_jobs
|
||||
):
|
||||
trial = self.smac.ask()
|
||||
if trial is None:
|
||||
self.rung_closed = True
|
||||
break
|
||||
# Create new trial with forced budget (TrialInfo is frozen, can't modify)
|
||||
trial_with_budget = TrialInfo(
|
||||
config=trial.config,
|
||||
instance=trial.instance,
|
||||
seed=trial.seed,
|
||||
budget=self.current_budget,
|
||||
)
|
||||
cfg_key = str(sorted(trial.config.items()))
|
||||
if cfg_key not in self.smac_asked_configs:
|
||||
self.smac_asked_configs.add(cfg_key)
|
||||
# Store as tuple: (trial, None) - no previous task for new configs
|
||||
self.pending_configs.append((trial_with_budget, None))
|
||||
|
||||
# Check if we've collected enough configs for this rung
|
||||
if (
|
||||
not self.rung_closed
|
||||
and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
|
||||
>= self.initial_rung_size
|
||||
):
|
||||
self.rung_closed = True
|
||||
|
||||
# Launch pending configs up to concurrent limit
|
||||
while self.pending_configs and len(self.running_tasks) < self.num_concurrent_workers:
|
||||
# Unpack tuple: (trial, prev_task_id)
|
||||
trial, prev_task_id = self.pending_configs.pop(0)
|
||||
t = self._launch_task(trial.config, self.current_budget, prev_task_id=prev_task_id)
|
||||
if t is None:
|
||||
# Launch failed, mark trial as failed and continue
|
||||
# Tell SMAC this trial failed with worst possible score
|
||||
cost = float("inf") if self.maximize_metric else float("-inf")
|
||||
self.smac.tell(trial, TrialValue(cost=cost))
|
||||
total_launched += 1
|
||||
continue
|
||||
self.running_tasks[t.id] = trial
|
||||
|
||||
# Track which task ID was used for this config at this budget
|
||||
cfg_key = str(sorted(trial.config.items()))
|
||||
if cfg_key not in self.config_to_tasks:
|
||||
self.config_to_tasks[cfg_key] = {}
|
||||
self.config_to_tasks[cfg_key][self.current_budget] = t.id
|
||||
|
||||
total_launched += 1
|
||||
|
||||
if not self.running_tasks and not self.pending_configs:
|
||||
break
|
||||
|
||||
# Abort if too many consecutive trials failed (likely a config bug)
|
||||
if (
|
||||
self.max_consecutive_failures > 0
|
||||
and self._consecutive_failures >= self.max_consecutive_failures
|
||||
):
|
||||
controller.get_logger().report_text(
|
||||
f"ABORTING: {self._consecutive_failures} consecutive trial "
|
||||
f"failures (limit: {self.max_consecutive_failures}). "
|
||||
"Check the trial logs for errors."
|
||||
)
|
||||
# Stop any still-running tasks
|
||||
for tid in list(self.running_tasks):
|
||||
with contextlib.suppress(Exception):
|
||||
t = self._get_task_safe(task_id=tid)
|
||||
if t:
|
||||
t.mark_stopped(force=True)
|
||||
self.running_tasks.clear()
|
||||
break
|
||||
|
||||
# Poll for finished or timed out
|
||||
done = []
|
||||
timed_out = []
|
||||
failed_to_check = []
|
||||
for tid, _tri in self.running_tasks.items():
|
||||
try:
|
||||
task = self._get_task_safe(task_id=tid)
|
||||
if task is None:
|
||||
failed_to_check.append(tid)
|
||||
continue
|
||||
|
||||
st = task.get_status()
|
||||
|
||||
# Check if task completed normally
|
||||
if st == Task.TaskStatusEnum.completed or st in (
|
||||
Task.TaskStatusEnum.failed,
|
||||
Task.TaskStatusEnum.stopped,
|
||||
):
|
||||
done.append(tid)
|
||||
# Check for timeout (if time limit is set)
|
||||
elif self.time_limit_per_job and tid in self.task_start_times:
|
||||
elapsed_minutes = (time.time() - self.task_start_times[tid]) / 60.0
|
||||
if elapsed_minutes > self.time_limit_per_job:
|
||||
with contextlib.suppress(Exception):
|
||||
task.mark_stopped(force=True)
|
||||
timed_out.append(tid)
|
||||
except Exception:
|
||||
# Don't mark as failed immediately, might be transient
|
||||
# Only mark failed after multiple consecutive failures
|
||||
if not hasattr(self, "_task_check_failures"):
|
||||
self._task_check_failures = {}
|
||||
self._task_check_failures[tid] = self._task_check_failures.get(tid, 0) + 1
|
||||
if self._task_check_failures[tid] >= 5: # 5 consecutive failures
|
||||
failed_to_check.append(tid)
|
||||
del self._task_check_failures[tid]
|
||||
|
||||
# Process tasks that failed to check
|
||||
for tid in failed_to_check:
|
||||
tri = self.running_tasks.pop(tid)
|
||||
if tid in self.task_start_times:
|
||||
del self.task_start_times[tid]
|
||||
# Tell SMAC this trial failed with worst possible score
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
cost = -res if self.maximize_metric else res
|
||||
self.smac.tell(tri, TrialValue(cost=cost))
|
||||
self.completed_results.append(
|
||||
{
|
||||
"task_id": tid,
|
||||
"config": tri.config,
|
||||
"budget": tri.budget,
|
||||
"value": res,
|
||||
"failed": True,
|
||||
}
|
||||
)
|
||||
# Store result with task_id for checkpoint tracking
|
||||
self.evaluated_at_budget.append((tri.config, res, tri, tid))
|
||||
|
||||
# Process completed tasks
|
||||
for tid in done:
|
||||
tri = self.running_tasks.pop(tid)
|
||||
if tid in self.task_start_times:
|
||||
del self.task_start_times[tid]
|
||||
|
||||
# Clear any accumulated failures for this task
|
||||
if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
|
||||
del self._task_check_failures[tid]
|
||||
|
||||
task = self._get_task_safe(task_id=tid)
|
||||
|
||||
# Detect hard-failed tasks (crashed / errored) vs completed
|
||||
task_failed = False
|
||||
if task is not None:
|
||||
st = task.get_status()
|
||||
task_failed = st in (
|
||||
Task.TaskStatusEnum.failed,
|
||||
Task.TaskStatusEnum.stopped,
|
||||
)
|
||||
|
||||
if task is None:
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
task_failed = True
|
||||
else:
|
||||
res = self._get_objective(task)
|
||||
|
||||
if res is None or res == float("-inf") or res == float("inf"):
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
|
||||
# Track consecutive failures for abort logic
|
||||
if task_failed:
|
||||
self._consecutive_failures += 1
|
||||
else:
|
||||
self._consecutive_failures = 0 # reset on any success
|
||||
|
||||
cost = -res if self.maximize_metric else res
|
||||
self.smac.tell(tri, TrialValue(cost=cost))
|
||||
self.completed_results.append(
|
||||
{
|
||||
"task_id": tid,
|
||||
"config": tri.config,
|
||||
"budget": tri.budget,
|
||||
"value": res,
|
||||
}
|
||||
)
|
||||
|
||||
# Store result for this budget rung with task_id for checkpoint tracking
|
||||
self.evaluated_at_budget.append((tri.config, res, tri, tid))
|
||||
|
||||
iteration = len(self.completed_results)
|
||||
|
||||
# Always report the trial score (even if it's bad)
|
||||
if res is not None and res != float("-inf") and res != float("inf"):
|
||||
controller.get_logger().report_scalar(
|
||||
title="Optimization", series="trial_score", value=res, iteration=iteration
|
||||
)
|
||||
controller.get_logger().report_scalar(
|
||||
title="Optimization",
|
||||
series="trial_budget",
|
||||
value=tri.budget or self.max_iterations,
|
||||
iteration=iteration,
|
||||
)
|
||||
|
||||
# Update best score tracking based on actual results
|
||||
if res is not None and res != float("-inf") and res != float("inf"):
|
||||
if self.maximize_metric:
|
||||
self.best_score_so_far = max(self.best_score_so_far, res)
|
||||
elif res < self.best_score_so_far:
|
||||
self.best_score_so_far = res
|
||||
|
||||
# Always report best score so far (shows flat line when no improvement)
|
||||
if self.best_score_so_far != float("-inf") and self.best_score_so_far != float("inf"):
|
||||
controller.get_logger().report_scalar(
|
||||
title="Optimization", series="best_score", value=self.best_score_so_far, iteration=iteration
|
||||
)
|
||||
|
||||
# Report running statistics
|
||||
valid_scores = [
|
||||
r["value"]
|
||||
for r in self.completed_results
|
||||
if r["value"] is not None and r["value"] != float("-inf") and r["value"] != float("inf")
|
||||
]
|
||||
if valid_scores:
|
||||
controller.get_logger().report_scalar(
|
||||
title="Optimization",
|
||||
series="mean_score",
|
||||
value=sum(valid_scores) / len(valid_scores),
|
||||
iteration=iteration,
|
||||
)
|
||||
controller.get_logger().report_scalar(
|
||||
title="Progress",
|
||||
series="completed_trials",
|
||||
value=len(self.completed_results),
|
||||
iteration=iteration,
|
||||
)
|
||||
controller.get_logger().report_scalar(
|
||||
title="Progress", series="running_tasks", value=len(self.running_tasks), iteration=iteration
|
||||
)
|
||||
|
||||
# Process timed out tasks (treat as failed with current objective value)
|
||||
for tid in timed_out:
|
||||
tri = self.running_tasks.pop(tid)
|
||||
if tid in self.task_start_times:
|
||||
del self.task_start_times[tid]
|
||||
|
||||
# Clear any accumulated failures for this task
|
||||
if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
|
||||
del self._task_check_failures[tid]
|
||||
|
||||
# Try to get the last objective value before timeout
|
||||
task = self._get_task_safe(task_id=tid)
|
||||
if task is None:
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
else:
|
||||
res = self._get_objective(task)
|
||||
|
||||
if res is None:
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
cost = -res if self.maximize_metric else res
|
||||
self.smac.tell(tri, TrialValue(cost=cost))
|
||||
self.completed_results.append(
|
||||
{
|
||||
"task_id": tid,
|
||||
"config": tri.config,
|
||||
"budget": tri.budget,
|
||||
"value": res,
|
||||
"timed_out": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Store timed out result for this budget rung with task_id
|
||||
self.evaluated_at_budget.append((tri.config, res, tri, tid))
|
||||
|
||||
time.sleep(self.pool_period_minutes * 60) # Fix: use correct attribute name from base class
|
||||
if self.compute_time_limit and controller.get_runtime() > self.compute_time_limit * 60:
|
||||
break
|
||||
|
||||
# Finalize
|
||||
self._finalize()
|
||||
return self.completed_results
|
||||
|
||||
@retry_on_error(max_retries=3, initial_delay=2.0, exceptions=(SendError, ConnectionError, KeyError))
|
||||
def _get_objective(self, task: Task):
|
||||
"""Get objective metric value with retry logic for robustness."""
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
m = task.get_last_scalar_metrics()
|
||||
if not m:
|
||||
return None
|
||||
|
||||
metric_data = m[self.metric_title][self.metric_series]
|
||||
|
||||
# ClearML returns dict with 'last', 'min', 'max' keys representing
|
||||
# the last/min/max values of this series over ALL logged iterations.
|
||||
# For snake_length/train_max: 'last' is the last logged train_max value,
|
||||
# 'max' is the highest train_max ever logged during training.
|
||||
|
||||
# Use 'max' if maximizing (we want the best performance achieved),
|
||||
# 'min' if minimizing, fallback to 'last'
|
||||
if self.maximize_metric and "max" in metric_data:
|
||||
result = metric_data["max"]
|
||||
elif not self.maximize_metric and "min" in metric_data:
|
||||
result = metric_data["min"]
|
||||
else:
|
||||
result = metric_data["last"]
|
||||
return result
|
||||
except (KeyError, Exception):
|
||||
return None
|
||||
|
||||
def _finalize(self):
|
||||
controller = Task.current_task()
|
||||
# Report final best score
|
||||
controller.get_logger().report_text(f"Final best score: {self.best_score_so_far}")
|
||||
|
||||
# Also try to get SMAC's incumbent for comparison
|
||||
try:
|
||||
incumbent = self.smac.intensifier.get_incumbent()
|
||||
if incumbent is not None:
|
||||
runhistory = self.smac.runhistory
|
||||
# Try different ways to get the cost
|
||||
incumbent_cost = None
|
||||
try:
|
||||
incumbent_cost = runhistory.get_cost(incumbent)
|
||||
except Exception:
|
||||
# Fallback: search through runhistory manually
|
||||
for trial_key, trial_value in runhistory.items():
|
||||
trial_config = runhistory.get_config(trial_key.config_id)
|
||||
if trial_config == incumbent and (incumbent_cost is None or trial_value.cost < incumbent_cost):
|
||||
incumbent_cost = trial_value.cost
|
||||
|
||||
if incumbent_cost is not None:
|
||||
score = -incumbent_cost if self.maximize_metric else incumbent_cost
|
||||
controller.get_logger().report_text(f"SMAC incumbent: {incumbent}, score: {score}")
|
||||
controller.upload_artifact(
|
||||
"best_config",
|
||||
{"config": dict(incumbent), "score": score, "our_best_score": self.best_score_so_far},
|
||||
)
|
||||
else:
|
||||
controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
|
||||
except Exception as e:
|
||||
controller.get_logger().report_text(f"Error getting SMAC incumbent: {e}")
|
||||
controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
|
||||
@@ -3,14 +3,95 @@ import torch.nn as nn
|
||||
from gymnasium import spaces
|
||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||
|
||||
|
||||
class HistoryEncoder(nn.Module):
|
||||
"""1D-CNN encoder over a temporal window of (obs, action) pairs.
|
||||
|
||||
Input: (batch, history_length, step_dim)
|
||||
Output: (batch, embedding_dim)
|
||||
|
||||
Architecture: two temporal conv layers → global average pool → linear.
|
||||
Lets the policy implicitly infer the current dynamics (friction, torque
|
||||
scale, latency, …) from how the system responded to recent actions —
|
||||
end-to-end adaptation when trained under domain randomization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
history_length: int,
|
||||
step_dim: int,
|
||||
embedding_dim: int = 32,
|
||||
hidden_channels: int = 32,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
# (batch, step_dim, history_length) after transpose
|
||||
nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1),
|
||||
nn.ELU(),
|
||||
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
|
||||
nn.ELU(),
|
||||
)
|
||||
self.fc = nn.Linear(hidden_channels, embedding_dim)
|
||||
|
||||
def forward(self, history: torch.Tensor) -> torch.Tensor:
|
||||
"""history: (batch, history_length, step_dim)."""
|
||||
# Conv1d expects (batch, channels, seq_len).
|
||||
x = history.transpose(1, 2)
|
||||
x = self.conv(x)
|
||||
# Global average pool over time.
|
||||
x = x.mean(dim=-1)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
"""Shared policy/value network with an optional history encoder.
|
||||
|
||||
With ``history_length > 0`` the input states are expected to be
|
||||
``[raw_obs, history_flat]`` (as produced by ``BaseRunner``); the history
|
||||
window is compressed by a :class:`HistoryEncoder` and concatenated with
|
||||
the raw observation before the shared backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
device: torch.device,
|
||||
hidden_sizes: tuple[int, ...] = (32, 32),
|
||||
clip_actions: bool = False,
|
||||
clip_log_std: bool = True,
|
||||
min_log_std: float = -2.0,
|
||||
max_log_std: float = 2.0,
|
||||
initial_log_std: float = 0.0,
|
||||
# ── History encoder ──────────────────────────────────────
|
||||
history_length: int = 0,
|
||||
raw_obs_dim: int = 0,
|
||||
embedding_dim: int = 32,
|
||||
):
|
||||
Model.__init__(self, observation_space, action_space, device)
|
||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
||||
GaussianMixin.__init__(
|
||||
self, clip_actions, clip_log_std, min_log_std, max_log_std,
|
||||
)
|
||||
DeterministicMixin.__init__(self, clip_actions)
|
||||
|
||||
layers = []
|
||||
in_dim: int = self.num_observations
|
||||
self._history_length = history_length
|
||||
self._raw_obs_dim = raw_obs_dim
|
||||
self._embedding_dim = embedding_dim
|
||||
|
||||
self.history_encoder: HistoryEncoder | None = None
|
||||
if history_length > 0 and raw_obs_dim > 0:
|
||||
step_dim = raw_obs_dim + self.num_actions
|
||||
self.history_encoder = HistoryEncoder(
|
||||
history_length=history_length,
|
||||
step_dim=step_dim,
|
||||
embedding_dim=embedding_dim,
|
||||
)
|
||||
in_dim = raw_obs_dim + embedding_dim
|
||||
else:
|
||||
in_dim = self.num_observations
|
||||
|
||||
# ── Shared backbone ──────────────────────────────────────
|
||||
layers: list[nn.Module] = []
|
||||
for hidden_size in hidden_sizes:
|
||||
layers.append(nn.Linear(in_dim, hidden_size))
|
||||
layers.append(nn.ELU())
|
||||
@@ -19,30 +100,45 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
|
||||
# Policy head
|
||||
self.mean_layer = nn.Linear(in_dim, self.num_actions)
|
||||
self.log_std_parameter: nn.Parameter = nn.Parameter(torch.full((self.num_actions,), initial_log_std))
|
||||
self.log_std_parameter: nn.Parameter = nn.Parameter(
|
||||
torch.full((self.num_actions,), initial_log_std),
|
||||
)
|
||||
|
||||
# Value head
|
||||
self.value_layer = nn.Linear(in_dim, 1)
|
||||
self._shared_output: torch.Tensor | None = None
|
||||
|
||||
|
||||
def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
|
||||
def act(
|
||||
self, inputs: dict[str, torch.Tensor], role: str = "",
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if role == "policy":
|
||||
return GaussianMixin.act(self, inputs, role)
|
||||
elif role == "value":
|
||||
return DeterministicMixin.act(self, inputs, role)
|
||||
|
||||
|
||||
def _encode(self, states: torch.Tensor) -> torch.Tensor:
|
||||
"""Optionally split off and encode the history window."""
|
||||
if self.history_encoder is None:
|
||||
return self.net(states)
|
||||
|
||||
obs = states[:, :self._raw_obs_dim]
|
||||
hist_flat = states[:, self._raw_obs_dim:]
|
||||
step_dim = self._raw_obs_dim + self.num_actions
|
||||
history = hist_flat.reshape(-1, self._history_length, step_dim)
|
||||
embedding = self.history_encoder(history)
|
||||
return self.net(torch.cat([obs, embedding], dim=-1))
|
||||
|
||||
def compute(
|
||||
self, inputs: dict[str, torch.Tensor], role: str = ""
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
self, inputs: dict[str, torch.Tensor], role: str = "",
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if role == "policy":
|
||||
self._shared_output = self.net(inputs["states"])
|
||||
self._shared_output = self._encode(inputs["states"])
|
||||
return self.mean_layer(self._shared_output), self.log_std_parameter, {}
|
||||
elif role == "value":
|
||||
shared_output = (
|
||||
self._shared_output
|
||||
if self._shared_output is not None
|
||||
else self.net(inputs["states"])
|
||||
else self._encode(inputs["states"])
|
||||
)
|
||||
self._shared_output = None
|
||||
return self.value_layer(shared_output), {}
|
||||
return self.value_layer(shared_output), {}
|
||||
|
||||
354
src/runners/mjx.py
Normal file
354
src/runners/mjx.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""GPU-batched MuJoCo simulation using MJX (JAX backend).
|
||||
|
||||
MJX runs all environments in parallel on GPU via JAX, providing
|
||||
~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).
|
||||
|
||||
Requirements:
|
||||
pip install 'jax[cuda12]' # NVIDIA GPU
|
||||
pip install jax # CPU fallback
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
|
||||
import structlog
|
||||
import torch
|
||||
|
||||
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
|
||||
# default JAX preallocates ~75% of GPU memory on init, starving torch and
|
||||
# causing OOM at the first PPO update. Disable preallocation so both libraries
|
||||
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
|
||||
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mujoco
|
||||
from mujoco import mjx
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"MJX runner requires JAX and MuJoCo MJX. Install with:\n"
|
||||
" pip install 'jax[cuda12]' # GPU\n"
|
||||
" pip install jax # CPU\n"
|
||||
) from e
|
||||
|
||||
import numpy as np
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
from src.runners.mujoco import (
|
||||
ActuatorLimits,
|
||||
load_mujoco_model,
|
||||
)
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MJXRunnerConfig(BaseRunnerConfig):
|
||||
num_envs: int = 1024
|
||||
device: str = "cuda"
|
||||
dt: float = 0.002
|
||||
substeps: int = 10
|
||||
|
||||
|
||||
class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
"""GPU-batched MuJoCo runner using MJX (JAX).
|
||||
|
||||
Physics runs entirely on GPU via JAX; observations flow to
|
||||
PyTorch through zero-copy DLPack transfers.
|
||||
"""
|
||||
|
||||
def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
|
||||
super().__init__(env, config)
|
||||
|
||||
@property
|
||||
def num_envs(self) -> int:
|
||||
return self.config.num_envs
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(self.config.device)
|
||||
|
||||
# ── Initialization ───────────────────────────────────────────────
|
||||
|
||||
def _sim_initialize(self, config: MJXRunnerConfig) -> None:
|
||||
# Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
|
||||
self._mj_model = load_mujoco_model(self.env.robot)
|
||||
self._mj_model.opt.timestep = config.dt
|
||||
self._nq = self._mj_model.nq
|
||||
self._nv = self._mj_model.nv
|
||||
self._nu = self._mj_model.nu
|
||||
|
||||
# Step 2: Put model on GPU
|
||||
self._mjx_model = mjx.put_model(self._mj_model)
|
||||
|
||||
# Step 3: Default reset state on GPU
|
||||
default_data = mujoco.MjData(self._mj_model)
|
||||
mujoco.mj_forward(self._mj_model, default_data)
|
||||
self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
|
||||
|
||||
# Env-defined initial-state distribution (shared with the CPU
|
||||
# runner) — baked into the JIT reset as constants.
|
||||
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
|
||||
self._nq, self._nv,
|
||||
)
|
||||
self._init_qpos_lo = jnp.asarray(qpos_lo)
|
||||
self._init_qpos_hi = jnp.asarray(qpos_hi)
|
||||
self._init_qvel_lo = jnp.asarray(qvel_lo)
|
||||
self._init_qvel_hi = jnp.asarray(qvel_hi)
|
||||
|
||||
# Step 4: Initialise all environments with randomized states
|
||||
self._rng = jax.random.PRNGKey(42)
|
||||
self._batch_data = self._make_batched_data(config.num_envs)
|
||||
|
||||
# Step 4b: Build motor model info (ctrl_idx, qvel_idx, ActuatorConfig)
|
||||
self._motor_info: list[tuple[int, int]] = []
|
||||
self._motor_acts: list = []
|
||||
for ctrl_idx, act in enumerate(self.env.robot.actuators):
|
||||
if act.has_motor_model:
|
||||
jnt_id = mujoco.mj_name2id(
|
||||
self._mj_model, mujoco.mjtObj.mjOBJ_JOINT, act.joint,
|
||||
)
|
||||
qvel_idx = self._mj_model.jnt_dofadr[jnt_id]
|
||||
self._motor_info.append((ctrl_idx, qvel_idx))
|
||||
self._motor_acts.append(act)
|
||||
|
||||
# Step 5: JIT-compile the hot-path functions
|
||||
self._compile_jit_fns(config.substeps)
|
||||
|
||||
# Keep one CPU MjData for offscreen rendering
|
||||
self._render_data = mujoco.MjData(self._mj_model)
|
||||
|
||||
# Per-env DR scale arrays (synced from torch on every reset).
|
||||
# Initialised to 1.0 here because _setup_domain_rand runs after this.
|
||||
self._mjx_fr = jnp.ones(config.num_envs)
|
||||
self._mjx_dp = jnp.ones(config.num_envs)
|
||||
self._mjx_tq = jnp.ones(config.num_envs)
|
||||
|
||||
log.info(
|
||||
"mjx_runner_ready",
|
||||
num_envs=config.num_envs,
|
||||
substeps=config.substeps,
|
||||
jax_devices=str(jax.devices()),
|
||||
)
|
||||
|
||||
def _make_batched_data(self, n: int):
|
||||
"""Create *n* environments with env-defined initial randomization."""
|
||||
self._rng, k1, k2 = jax.random.split(self._rng, 3)
|
||||
pq = jax.random.uniform(
|
||||
k1, (n, self._nq),
|
||||
minval=self._init_qpos_lo, maxval=self._init_qpos_hi,
|
||||
)
|
||||
pv = jax.random.uniform(
|
||||
k2, (n, self._nv),
|
||||
minval=self._init_qvel_lo, maxval=self._init_qvel_hi,
|
||||
)
|
||||
|
||||
default = self._default_mjx_data
|
||||
model = self._mjx_model
|
||||
|
||||
def _init_one(pq_i, pv_i):
|
||||
d = default.replace(
|
||||
qpos=default.qpos + pq_i,
|
||||
qvel=default.qvel + pv_i,
|
||||
)
|
||||
return mjx.forward(model, d)
|
||||
|
||||
return jax.vmap(_init_one)(pq, pv)
|
||||
|
||||
def _compile_jit_fns(self, substeps: int) -> None:
|
||||
"""Pre-compile the two hot-path functions so the first call is fast."""
|
||||
model = self._mjx_model
|
||||
default = self._default_mjx_data
|
||||
|
||||
lim = ActuatorLimits(self._mj_model)
|
||||
act_jnt_ids = jnp.array(lim.jnt_ids)
|
||||
act_limited = jnp.array(lim.limited)
|
||||
act_lo = jnp.array(lim.lo)
|
||||
act_hi = jnp.array(lim.hi)
|
||||
act_gs = jnp.array(lim.gear_sign)
|
||||
|
||||
# ── Motor model params (JAX arrays for JIT) ─────────────────
|
||||
# Must stay in lock-step with ActuatorConfig.transform_ctrl() /
|
||||
# compute_motor_force() in src/core/robot.py — sysid fits against
|
||||
# the CPU implementation.
|
||||
_has_motor = len(self._motor_info) > 0
|
||||
if _has_motor:
|
||||
acts = self._motor_acts
|
||||
_ctrl_ids = jnp.array([c for c, _ in self._motor_info])
|
||||
_qvel_ids = jnp.array([q for _, q in self._motor_info])
|
||||
_ctrl_lo = jnp.array([a.ctrl_range[0] for a in acts])
|
||||
_ctrl_hi = jnp.array([a.ctrl_range[1] for a in acts])
|
||||
_bias = jnp.array([a.action_bias for a in acts])
|
||||
_dz_pos = jnp.array([a.deadzone[0] for a in acts])
|
||||
_dz_neg = jnp.array([a.deadzone[1] for a in acts])
|
||||
_gear_pos = jnp.array([a.gear[0] for a in acts])
|
||||
_gear_neg = jnp.array([a.gear[1] for a in acts])
|
||||
_gear_avg = jnp.array([a.gear_avg for a in acts])
|
||||
_fl_pos = jnp.array([a.frictionloss[0] for a in acts])
|
||||
_fl_neg = jnp.array([a.frictionloss[1] for a in acts])
|
||||
_strb_boost = jnp.array([a.stribeck_friction_boost for a in acts])
|
||||
_strb_vel = jnp.array([a.stribeck_vel for a in acts])
|
||||
_damp_pos = jnp.array([a.damping[0] for a in acts])
|
||||
_damp_neg = jnp.array([a.damping[1] for a in acts])
|
||||
_visc_quad = jnp.array([a.viscous_quadratic for a in acts])
|
||||
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
||||
|
||||
# ── Batched step (N substeps per call) ──────────────────────
|
||||
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
|
||||
# multipliers (1.0 = off). Passed as args (not closure constants) so
|
||||
# resampling them every episode does NOT trigger JIT recompilation.
|
||||
@jax.jit
|
||||
def step_fn(data, fr_scale, dp_scale, tq_scale):
|
||||
fr = fr_scale[:, None] # broadcast over motor actuators
|
||||
dp = dp_scale[:, None]
|
||||
tq = tq_scale[:, None]
|
||||
|
||||
# Software limit switch: clamp ctrl once before substeps.
|
||||
pos = data.qpos[:, act_jnt_ids]
|
||||
ctrl = data.ctrl
|
||||
at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0)
|
||||
at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0)
|
||||
ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl)
|
||||
|
||||
if _has_motor:
|
||||
# Clip → bias → deadzone → asymmetric gear compensation
|
||||
# (same order as ActuatorConfig.transform_ctrl).
|
||||
mc = ctrl[:, _ctrl_ids]
|
||||
mc = jnp.clip(mc, _ctrl_lo, _ctrl_hi)
|
||||
mc = mc + _bias
|
||||
mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc)
|
||||
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
||||
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
||||
mc = mc * gear_dir / _gear_avg
|
||||
mc = mc * tq # torque_scale (DR)
|
||||
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
||||
|
||||
data = data.replace(ctrl=ctrl)
|
||||
|
||||
def body(_, d):
|
||||
if _has_motor:
|
||||
vel = d.qvel[:, _qvel_ids]
|
||||
mc = d.ctrl[:, _ctrl_ids]
|
||||
|
||||
# Coulomb + Stribeck friction (direction-dependent) × DR
|
||||
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
|
||||
fl = fl + _strb_boost * jnp.exp(
|
||||
-((jnp.abs(vel) / _strb_vel) ** 2)
|
||||
)
|
||||
fl = fl * fr
|
||||
torque = -jnp.where(
|
||||
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
||||
)
|
||||
# Viscous damping (direction-dependent) × DR scale
|
||||
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
|
||||
torque = torque - damp * vel
|
||||
# Quadratic velocity drag
|
||||
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
||||
# Back-EMF torque reduction
|
||||
bemf = _back_emf * vel * jnp.sign(mc)
|
||||
torque = torque - jnp.where(
|
||||
jnp.abs(mc) > 1e-6, bemf, 0.0,
|
||||
)
|
||||
torque = jnp.clip(torque, -10.0, 10.0)
|
||||
d = d.replace(
|
||||
qfrc_applied=d.qfrc_applied.at[:, _qvel_ids].set(torque),
|
||||
)
|
||||
|
||||
return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
|
||||
|
||||
return jax.lax.fori_loop(0, substeps, body, data)
|
||||
|
||||
self._jit_step = step_fn
|
||||
|
||||
# ── Selective reset ─────────────────────────────────────────
|
||||
init_qpos_lo = self._init_qpos_lo
|
||||
init_qpos_hi = self._init_qpos_hi
|
||||
init_qvel_lo = self._init_qvel_lo
|
||||
init_qvel_hi = self._init_qvel_hi
|
||||
|
||||
@jax.jit
|
||||
def reset_fn(data, mask, rng):
|
||||
rng, k1, k2 = jax.random.split(rng, 3)
|
||||
ne = data.qpos.shape[0]
|
||||
|
||||
pq = jax.random.uniform(
|
||||
k1, (ne, default.qpos.shape[0]),
|
||||
minval=init_qpos_lo, maxval=init_qpos_hi,
|
||||
)
|
||||
pv = jax.random.uniform(
|
||||
k2, (ne, default.qvel.shape[0]),
|
||||
minval=init_qvel_lo, maxval=init_qvel_hi,
|
||||
)
|
||||
|
||||
m = mask[:, None] # (num_envs, 1) broadcast helper
|
||||
|
||||
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
|
||||
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
|
||||
new_ctrl = jnp.where(m, 0.0, data.ctrl)
|
||||
|
||||
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
|
||||
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
|
||||
|
||||
return new_data, rng
|
||||
|
||||
self._jit_reset = reset_fn
|
||||
|
||||
# ── Step / Reset ─────────────────────────────────────────────────
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||
|
||||
# Set ctrl & run N substeps for all environments (with per-env DR scales)
|
||||
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||
self._batch_data = self._jit_step(
|
||||
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
|
||||
)
|
||||
|
||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||
qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
|
||||
return qpos, qvel
|
||||
|
||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Build boolean mask (fixed shape → no JIT recompilation)
|
||||
mask = torch.zeros(
|
||||
self.config.num_envs, dtype=torch.bool, device=self.device,
|
||||
)
|
||||
mask[env_ids] = True
|
||||
mask_jax = jnp.from_dlpack(mask)
|
||||
|
||||
self._batch_data, self._rng = self._jit_reset(
|
||||
self._batch_data, mask_jax, self._rng,
|
||||
)
|
||||
|
||||
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
|
||||
# resamples self._dr_scales just before this call, so re-deriving the
|
||||
# full arrays here keeps the JAX copies current for every env.
|
||||
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
|
||||
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
|
||||
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
|
||||
|
||||
# Return the FULL batch (BaseRunner indexes the reset envs in torch)
|
||||
# — avoids a GPU→CPU sync + JAX gather on every step with a done env.
|
||||
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||
qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
|
||||
return qpos, qvel
|
||||
|
||||
# ── Rendering ────────────────────────────────────────────────────
|
||||
|
||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Offscreen render — copies one env's state from GPU to CPU."""
|
||||
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
|
||||
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
|
||||
self._render_data.ctrl[:] = np.asarray(self._batch_data.ctrl[env_idx])
|
||||
mujoco.mj_forward(self._mj_model, self._render_data)
|
||||
|
||||
if not hasattr(self, "_offscreen_renderer"):
|
||||
self._offscreen_renderer = mujoco.Renderer(
|
||||
self._mj_model, width=640, height=480,
|
||||
)
|
||||
self._offscreen_renderer.update_scene(self._render_data)
|
||||
return self._offscreen_renderer.render().copy()
|
||||
@@ -1,19 +1,206 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.robot import RobotConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
num_envs: int = 16
|
||||
device: str = "cpu"
|
||||
dt: float = 0.02
|
||||
substeps: int = 2
|
||||
dt: float = 0.002
|
||||
substeps: int = 10
|
||||
|
||||
|
||||
class ActuatorLimits:
|
||||
"""Software limit-switch: cuts motor ctrl when a joint hits its range.
|
||||
|
||||
The real robot has physical limit switches that kill motor current
|
||||
at the travel endpoints. MuJoCo's built-in joint limits only apply
|
||||
a spring force — they don't zero the actuator signal. This class
|
||||
replicates the hardware behavior.
|
||||
"""
|
||||
|
||||
def __init__(self, model: mujoco.MjModel) -> None:
|
||||
jnt_ids = model.actuator_trnid[:model.nu, 0]
|
||||
self.jnt_ids = jnt_ids
|
||||
self.limited = model.jnt_limited[jnt_ids].astype(bool)
|
||||
self.lo = model.jnt_range[jnt_ids, 0]
|
||||
self.hi = model.jnt_range[jnt_ids, 1]
|
||||
self.gear_sign = np.sign(model.actuator_gear[:model.nu, 0])
|
||||
|
||||
def enforce(self, model: mujoco.MjModel, data: mujoco.MjData) -> None:
|
||||
"""Zero ctrl that would push past joint limits (call every substep)."""
|
||||
if not self.limited.any():
|
||||
return
|
||||
pos = data.qpos[self.jnt_ids]
|
||||
signed_ctrl = self.gear_sign * data.ctrl[:model.nu]
|
||||
at_hi = self.limited & (pos >= self.hi) & (signed_ctrl > 0)
|
||||
at_lo = self.limited & (pos <= self.lo) & (signed_ctrl < 0)
|
||||
data.ctrl[at_hi | at_lo] = 0.0
|
||||
|
||||
|
||||
# ── Public utilities ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def load_mujoco_model(robot: RobotConfig) -> mujoco.MjModel:
|
||||
"""Load a URDF (or MJCF) and apply robot.yaml settings.
|
||||
|
||||
Single model-loading entry point for all MuJoCo-based code:
|
||||
training runners, MJX, and system identification.
|
||||
|
||||
Two-step approach required because MuJoCo's URDF parser ignores
|
||||
``<actuator>`` in the ``<mujoco>`` extension block:
|
||||
|
||||
1. Load the URDF -> MuJoCo converts it to internal MJCF
|
||||
2. Export the MJCF XML, inject actuators + joint overrides, reload
|
||||
|
||||
This keeps the URDF clean (re-exportable from CAD) -- all hardware
|
||||
tuning lives in ``robot.yaml``.
|
||||
"""
|
||||
abs_path = robot.urdf_path.resolve()
|
||||
model_dir = abs_path.parent
|
||||
is_urdf = abs_path.suffix.lower() == ".urdf"
|
||||
|
||||
# -- Step 1: Load URDF with meshdir injection --
|
||||
if is_urdf:
|
||||
tree = ET.parse(abs_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# MuJoCo's URDF parser strips directory prefixes from mesh
|
||||
# filenames, so we inject a <mujoco><compiler meshdir="..."/>
|
||||
# block. The original URDF stays clean and simulator-agnostic.
|
||||
meshdir = None
|
||||
for mesh_el in root.iter("mesh"):
|
||||
fn = mesh_el.get("filename", "")
|
||||
parent = str(Path(fn).parent)
|
||||
if parent and parent != ".":
|
||||
meshdir = parent
|
||||
break
|
||||
if meshdir:
|
||||
mj_ext = ET.SubElement(root, "mujoco")
|
||||
ET.SubElement(mj_ext, "compiler", attrib={
|
||||
"meshdir": meshdir,
|
||||
"balanceinertia": "true",
|
||||
})
|
||||
|
||||
# Write to a temp file (unique name for multiprocessing safety).
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
suffix=".urdf", prefix="_mj_", dir=str(model_dir),
|
||||
)
|
||||
os.close(fd)
|
||||
try:
|
||||
tree.write(tmp_path, xml_declaration=True, encoding="unicode")
|
||||
model_raw = mujoco.MjModel.from_xml_path(tmp_path)
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
|
||||
|
||||
# If robot.yaml has no actuators/joints, we're done.
|
||||
if not robot.actuators and not robot.joints:
|
||||
return model_raw
|
||||
|
||||
# -- Step 2: Export MJCF, inject actuators + joint overrides --
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
suffix=".xml", prefix="_mj_", dir=str(model_dir),
|
||||
)
|
||||
os.close(fd)
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_path, model_raw)
|
||||
mjcf_str = Path(tmp_path).read_text()
|
||||
root = ET.fromstring(mjcf_str)
|
||||
|
||||
# -- Inject actuators --
|
||||
if robot.actuators:
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in robot.actuators:
|
||||
attribs = {
|
||||
"name": f"{act.joint}_{act.type}",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear_avg),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
}
|
||||
|
||||
# dyntype is only available on <general>, not on
|
||||
# shortcut elements like <motor>/<position>/<velocity>.
|
||||
use_general = act.filter_tau > 0
|
||||
|
||||
if use_general:
|
||||
attribs["dyntype"] = "filter"
|
||||
attribs["dynprm"] = str(act.filter_tau)
|
||||
attribs["gaintype"] = "fixed"
|
||||
if act.type == "position":
|
||||
attribs["biastype"] = "affine"
|
||||
attribs["gainprm"] = str(act.kp)
|
||||
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
|
||||
elif act.type == "velocity":
|
||||
attribs["biastype"] = "affine"
|
||||
attribs["gainprm"] = str(act.kp)
|
||||
attribs["biasprm"] = f"0 0 -{act.kp}"
|
||||
else: # motor
|
||||
attribs["biastype"] = "none"
|
||||
ET.SubElement(act_elem, "general", attrib=attribs)
|
||||
else:
|
||||
if act.type == "position":
|
||||
attribs["kp"] = str(act.kp)
|
||||
if act.kv > 0:
|
||||
attribs["kv"] = str(act.kv)
|
||||
elif act.type == "velocity":
|
||||
attribs["kp"] = str(act.kp)
|
||||
ET.SubElement(act_elem, act.type, attrib=attribs)
|
||||
|
||||
# -- Apply joint overrides from robot.yaml --
|
||||
# For actuated joints with a motor model, MuJoCo damping/frictionloss
|
||||
# are set to 0 — the motor model handles them via qfrc_applied.
|
||||
joint_damping: dict[str, float] = {}
|
||||
joint_frictionloss: dict[str, float] = {}
|
||||
for act in robot.actuators:
|
||||
if act.has_motor_model:
|
||||
joint_damping[act.joint] = 0.0
|
||||
joint_frictionloss[act.joint] = 0.0
|
||||
joint_armature: dict[str, float] = {}
|
||||
for name, jcfg in robot.joints.items():
|
||||
if jcfg.damping is not None:
|
||||
joint_damping[name] = jcfg.damping
|
||||
if jcfg.armature is not None:
|
||||
joint_armature[name] = jcfg.armature
|
||||
if jcfg.frictionloss is not None:
|
||||
joint_frictionloss[name] = jcfg.frictionloss
|
||||
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
name = jnt.get("name")
|
||||
if name in joint_damping:
|
||||
jnt.set("damping", str(joint_damping[name]))
|
||||
if name in joint_armature:
|
||||
jnt.set("armature", str(joint_armature[name]))
|
||||
if name in joint_frictionloss:
|
||||
jnt.set("frictionloss", str(joint_frictionloss[name]))
|
||||
|
||||
# Disable self-collision on all geoms.
|
||||
for geom in root.iter("geom"):
|
||||
geom.set("contype", "0")
|
||||
geom.set("conaffinity", "0")
|
||||
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
Path(tmp_path).write_text(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(tmp_path)
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
# -- Runner -----------------------------------------------------------
|
||||
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
||||
@@ -22,76 +209,62 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
@property
|
||||
def num_envs(self) -> int:
|
||||
return self.config.num_envs
|
||||
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device(self.config.device)
|
||||
|
||||
@staticmethod
|
||||
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
|
||||
"""Load a URDF (or MJCF) file and programmatically inject actuators.
|
||||
|
||||
Two-step approach required because MuJoCo's URDF parser ignores
|
||||
<actuator> in the <mujoco> extension block:
|
||||
1. Load the URDF → MuJoCo converts it to internal MJCF
|
||||
2. Export the MJCF XML, add <actuator> elements, reload
|
||||
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
"""
|
||||
# Step 1: Load URDF/MJCF as-is (no actuators)
|
||||
model_raw = mujoco.MjModel.from_xml_path(model_path)
|
||||
|
||||
if not actuators:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation
|
||||
tmp_mjcf = tempfile.mktemp(suffix=".xml")
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||
with open(tmp_mjcf) as f:
|
||||
mjcf_str = f.read()
|
||||
finally:
|
||||
import os
|
||||
os.unlink(tmp_mjcf)
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Step 4: Reload from modified MJCF
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
return mujoco.MjModel.from_xml_string(modified_xml)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
if model_path is None:
|
||||
raise ValueError("model_path must be specified in the environment config")
|
||||
|
||||
actuators = self.env.config.actuators
|
||||
self._model = self._load_model_with_actuators(str(model_path), actuators)
|
||||
self._model = load_mujoco_model(self.env.robot)
|
||||
self._model.opt.timestep = config.dt
|
||||
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
|
||||
|
||||
self._data: list[mujoco.MjData] = [
|
||||
mujoco.MjData(self._model) for _ in range(config.num_envs)
|
||||
]
|
||||
self._nq = self._model.nq
|
||||
self._nv = self._model.nv
|
||||
self._limits = ActuatorLimits(self._model)
|
||||
|
||||
# Build motor model: list of (actuator_config, joint_qvel_index) for
|
||||
# actuators that have asymmetric motor dynamics.
|
||||
self._motor_actuators: list[tuple] = []
|
||||
for act in self.env.robot.actuators:
|
||||
if act.has_motor_model:
|
||||
jnt_id = mujoco.mj_name2id(self._model, mujoco.mjtObj.mjOBJ_JOINT, act.joint)
|
||||
qvel_idx = self._model.jnt_dofadr[jnt_id]
|
||||
self._motor_actuators.append((act, qvel_idx))
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
actions_np: np.ndarray = actions.cpu().numpy()
|
||||
|
||||
# Apply per-actuator ctrl transform (deadzone + gear compensation)
|
||||
for act_idx, (act, _) in enumerate(self._motor_actuators):
|
||||
for env_idx in range(self.num_envs):
|
||||
actions_np[env_idx, act_idx] = act.transform_ctrl(
|
||||
float(actions_np[env_idx, act_idx])
|
||||
)
|
||||
|
||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||
|
||||
# Per-env domain-randomization scales (all 1.0 when DR is disabled).
|
||||
fr_scale = self._dr_scales["friction_scale"].cpu().numpy()
|
||||
dp_scale = self._dr_scales["damping_scale"].cpu().numpy()
|
||||
tq_scale = self._dr_scales["torque_scale"].cpu().numpy()
|
||||
|
||||
for i, data in enumerate(self._data):
|
||||
data.ctrl[:] = actions_np[i]
|
||||
# torque_scale emulates motor-constant / battery-voltage variation.
|
||||
data.ctrl[:] = actions_np[i] * tq_scale[i]
|
||||
for _ in range(self.config.substeps):
|
||||
# Apply asymmetric motor forces via qfrc_applied
|
||||
for act, qvel_idx in self._motor_actuators:
|
||||
vel = data.qvel[qvel_idx]
|
||||
ctrl = data.ctrl[0] # TODO: generalise for multi-actuator
|
||||
data.qfrc_applied[qvel_idx] = act.compute_motor_force(
|
||||
vel, ctrl,
|
||||
friction_scale=fr_scale[i],
|
||||
damping_scale=dp_scale[i],
|
||||
)
|
||||
self._limits.enforce(self._model, data)
|
||||
mujoco.mj_step(self._model, data)
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
@@ -101,55 +274,37 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
torch.from_numpy(qpos_batch).to(self.device),
|
||||
torch.from_numpy(qvel_batch).to(self.device),
|
||||
)
|
||||
|
||||
|
||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
ids = env_ids.cpu().numpy()
|
||||
n = len(ids)
|
||||
|
||||
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
||||
# Env-defined initial-state distribution (shared with the MJX runner).
|
||||
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
|
||||
self._nq, self._nv,
|
||||
)
|
||||
|
||||
for i, env_id in enumerate(ids):
|
||||
for env_id in ids:
|
||||
data = self._data[env_id]
|
||||
mujoco.mj_resetData(self._model, data)
|
||||
|
||||
# Add small random perturbation so the pole doesn't start perfectly upright
|
||||
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
qvel_batch[i] = data.qvel
|
||||
data.qpos[:] += np.random.uniform(qpos_lo, qpos_hi)
|
||||
data.qvel[:] += np.random.uniform(qvel_lo, qvel_hi)
|
||||
data.ctrl[:] = 0.0
|
||||
|
||||
# Full-batch return (see BaseRunner._sim_reset contract).
|
||||
qpos_batch = np.stack([d.qpos for d in self._data]).astype(np.float32)
|
||||
qvel_batch = np.stack([d.qvel for d in self._data]).astype(np.float32)
|
||||
return (
|
||||
torch.from_numpy(qpos_batch).to(self.device),
|
||||
torch.from_numpy(qvel_batch).to(self.device),
|
||||
)
|
||||
|
||||
def _sim_close(self) -> None:
|
||||
if hasattr(self, "_viewer") and self._viewer is not None:
|
||||
self._viewer.close()
|
||||
self._viewer = None
|
||||
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
self._offscreen_renderer = None
|
||||
|
||||
self._data.clear()
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
if mode == "human":
|
||||
if not hasattr(self, "_viewer") or self._viewer is None:
|
||||
self._viewer = mujoco.viewer.launch_passive(
|
||||
self._model, self._data[env_idx]
|
||||
)
|
||||
# Update visual geometry from current physics state
|
||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||
self._viewer.sync()
|
||||
return None
|
||||
elif mode == "rgb_array":
|
||||
# Cache the offscreen renderer to avoid create/destroy overhead
|
||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
|
||||
return torch.from_numpy(pixels)
|
||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Offscreen render of a single environment."""
|
||||
if not hasattr(self, "_offscreen_renderer"):
|
||||
self._offscreen_renderer = mujoco.Renderer(
|
||||
self._model, width=640, height=480,
|
||||
)
|
||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
return self._offscreen_renderer.render().copy()
|
||||
|
||||
494
src/runners/serial.py
Normal file
494
src/runners/serial.py
Normal file
@@ -0,0 +1,494 @@
|
||||
"""Serial runner — real hardware over USB/serial (ESP32).
|
||||
|
||||
Implements the BaseRunner interface for a single physical robot.
|
||||
All physics come from the real world; the runner translates between
|
||||
the ESP32 serial protocol and the qpos/qvel tensors that BaseRunner
|
||||
and BaseEnv expect.
|
||||
|
||||
Serial protocol (ESP32 firmware):
|
||||
Commands sent TO the ESP32:
|
||||
G — start streaming state lines
|
||||
H — stop streaming
|
||||
M<int> — set motor PWM speed (-255 … 255)
|
||||
|
||||
State lines received FROM the ESP32 (firmware sends SI units):
|
||||
S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
|
||||
(7 comma-separated fields after the ``S`` prefix)
|
||||
|
||||
motor_rad — motor joint angle (radians)
|
||||
motor_vel — motor joint velocity (rad/s)
|
||||
pend_rad — pendulum angle (radians, 0 = hanging down)
|
||||
pend_vel — pendulum angular velocity (rad/s)
|
||||
motor_speed — applied PWM (-255..255, for action recording)
|
||||
|
||||
A daemon thread continuously reads the serial stream so the control
|
||||
loop never blocks on I/O.
|
||||
|
||||
Usage:
|
||||
python train.py env=rotary_cartpole runner=serial training=ppo_real
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SerialRunnerConfig(BaseRunnerConfig):
|
||||
"""Configuration for serial communication with the ESP32."""
|
||||
|
||||
num_envs: int = 1 # always 1 — single physical robot
|
||||
device: str = "cpu"
|
||||
|
||||
port: str = "/dev/cu.usbserial-0001"
|
||||
baud: int = 115200
|
||||
dt: float = 0.04 # control loop period (seconds), 25 Hz
|
||||
no_data_timeout: float = 2.0 # seconds of silence → disconnect
|
||||
|
||||
# Physical reset procedure
|
||||
reset_drive_speed: int = 80 # PWM for bang-bang drive-to-center
|
||||
reset_deadband_rad: float = 0.01 # "centered" threshold (~0.6°)
|
||||
reset_drive_timeout: float = 3.0 # seconds to reach center
|
||||
reset_settle_timeout: float = 30.0 # seconds to wait for pendulum
|
||||
|
||||
|
||||
class SerialRunner(BaseRunner[SerialRunnerConfig]):
|
||||
"""BaseRunner implementation that talks to real hardware over serial.
|
||||
|
||||
Maps the ESP32 serial protocol to qpos/qvel tensors so the existing
|
||||
RotaryCartPoleEnv (or any compatible env) works unchanged.
|
||||
|
||||
qpos layout: [motor_angle_rad, pendulum_angle_rad]
|
||||
qvel layout: [motor_vel_rad_s, pendulum_vel_rad_s]
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseRunner interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def num_envs(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
def _sim_initialize(self, config: SerialRunnerConfig) -> None:
|
||||
# Joint dimensions for the rotary cartpole (motor + pendulum).
|
||||
self._nq = 2
|
||||
self._nv = 2
|
||||
|
||||
# Import serial here so it's not a hard dependency for sim-only users.
|
||||
import serial as _serial
|
||||
|
||||
self._serial_mod = _serial
|
||||
|
||||
# Explicitly disable hardware flow control and exclusive mode to
|
||||
# avoid termios.error (errno 22) on macOS with CH340/CP2102 adapters.
|
||||
self.ser: _serial.Serial = _serial.Serial(
|
||||
port=config.port,
|
||||
baudrate=config.baud,
|
||||
timeout=0.05,
|
||||
xonxoff=False,
|
||||
rtscts=False,
|
||||
dsrdtr=False,
|
||||
exclusive=False,
|
||||
)
|
||||
time.sleep(2) # Wait for ESP32 boot.
|
||||
self.ser.reset_input_buffer()
|
||||
|
||||
# Internal state tracking.
|
||||
self._rebooted: bool = False
|
||||
self._serial_disconnected: bool = False
|
||||
self._last_esp_ms: int = 0
|
||||
self._last_data_time: float = time.monotonic()
|
||||
self._streaming: bool = False
|
||||
|
||||
# Latest parsed state (updated by the reader thread).
|
||||
# Firmware sends SI units — values are used directly as qpos/qvel.
|
||||
self._latest_state: dict[str, Any] = {
|
||||
"timestamp_ms": 0,
|
||||
"motor_rad": 0.0,
|
||||
"motor_vel": 0.0,
|
||||
"pend_rad": 0.0,
|
||||
"pend_vel": 0.0,
|
||||
"motor_speed": 0,
|
||||
}
|
||||
self._state_lock = threading.Lock()
|
||||
self._state_event = threading.Event()
|
||||
|
||||
# Start background serial reader.
|
||||
self._reader_running = True
|
||||
self._reader_thread = threading.Thread(
|
||||
target=self._serial_reader, daemon=True
|
||||
)
|
||||
self._reader_thread.start()
|
||||
|
||||
# Start streaming.
|
||||
self._send("G")
|
||||
self._streaming = True
|
||||
self._last_data_time = time.monotonic()
|
||||
|
||||
# Derive max PWM from actuator ctrl_range so the serial
|
||||
# command range matches what MuJoCo training sees.
|
||||
ctrl_hi = self.env.robot.actuators[0].ctrl_range[1]
|
||||
self._max_pwm: int = round(ctrl_hi * 255)
|
||||
|
||||
# Track wall-clock time of last step for PPO-gap detection.
|
||||
self._last_step_time: float = time.monotonic()
|
||||
|
||||
def _sim_step(
|
||||
self, actions: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
now = time.monotonic()
|
||||
|
||||
# Detect PPO update gap: if more than 0.5s since last step,
|
||||
# the optimizer was running and no motor commands were sent.
|
||||
# Trigger a full reset so the robot starts from a clean state.
|
||||
gap = now - self._last_step_time
|
||||
if gap > 0.5:
|
||||
logger.info(
|
||||
"PPO update gap detected (%.1f s) — resetting before resuming.",
|
||||
gap,
|
||||
)
|
||||
self._send("M0")
|
||||
all_ids = torch.arange(self.num_envs, device=self.device)
|
||||
self._sim_reset(all_ids)
|
||||
self.step_counts.zero_()
|
||||
|
||||
# Map normalised action [-1, 1] → PWM, scaled by ctrl_range.
|
||||
action_val = float(actions[0, 0].clamp(-1.0, 1.0))
|
||||
motor_speed = int(action_val * self._max_pwm)
|
||||
self._send(f"M{motor_speed}")
|
||||
|
||||
# Stream-driven: block until the firmware sends the next state
|
||||
# line (~20 ms at 50 Hz).
|
||||
state = self._read_state_blocking(timeout=0.1)
|
||||
|
||||
# Firmware sends SI units — use directly.
|
||||
qpos, qvel = self._state_to_tensors(state)
|
||||
|
||||
# Cache for _sync_viz().
|
||||
self._last_sync_state = state
|
||||
self._last_step_time = time.monotonic()
|
||||
|
||||
return qpos, qvel
|
||||
|
||||
def _sim_reset(
|
||||
self, env_ids: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# If ESP32 rebooted or disconnected, we can't recover.
|
||||
if self._rebooted or self._serial_disconnected:
|
||||
raise RuntimeError(
|
||||
"ESP32 rebooted or disconnected during training! "
|
||||
"Encoder center is lost. "
|
||||
"Please re-center the motor manually and restart."
|
||||
)
|
||||
|
||||
# Stop motor and restart streaming.
|
||||
self._send("M0")
|
||||
self._send("H")
|
||||
self._streaming = False
|
||||
time.sleep(0.05)
|
||||
self._state_event.clear()
|
||||
self._send("G")
|
||||
self._streaming = True
|
||||
self._last_data_time = time.monotonic()
|
||||
time.sleep(0.05)
|
||||
|
||||
# Physically return the motor to the centre position.
|
||||
self._drive_to_center()
|
||||
|
||||
# Wait until the env considers the robot settled.
|
||||
self._wait_for_settle()
|
||||
|
||||
# Refresh data timer so health checks don't false-positive.
|
||||
self._last_data_time = time.monotonic()
|
||||
|
||||
# Read settled state and return as qpos/qvel.
|
||||
state = self._read_state_blocking()
|
||||
qpos, qvel = self._state_to_tensors(state)
|
||||
|
||||
self._last_sync_state = state
|
||||
return qpos, qvel
|
||||
|
||||
def _sim_close(self) -> None:
|
||||
self._reader_running = False
|
||||
self._streaming = False
|
||||
self._send("H")
|
||||
self._send("M0")
|
||||
time.sleep(0.1)
|
||||
self._reader_thread.join(timeout=1.0)
|
||||
self.ser.close()
|
||||
super()._sim_close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# MuJoCo digital-twin rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ensure_viz_model(self) -> None:
|
||||
"""Lazily load the MuJoCo model for visualisation (digital twin)."""
|
||||
if hasattr(self, "_viz_model"):
|
||||
return
|
||||
|
||||
import mujoco
|
||||
from src.runners.mujoco import load_mujoco_model
|
||||
|
||||
self._viz_model = load_mujoco_model(self.env.robot)
|
||||
self._viz_data = mujoco.MjData(self._viz_model)
|
||||
self._offscreen_renderer = None
|
||||
|
||||
def _sync_viz(self) -> None:
|
||||
"""Copy current serial sensor state into the MuJoCo viz model."""
|
||||
import mujoco
|
||||
|
||||
self._ensure_viz_model()
|
||||
|
||||
last_state = getattr(self, "_last_sync_state", None)
|
||||
if last_state is None:
|
||||
last_state = self._read_state()
|
||||
|
||||
# Firmware sends radians — use directly.
|
||||
self._viz_data.qpos[0] = last_state["motor_rad"]
|
||||
self._viz_data.qpos[1] = last_state["pend_rad"]
|
||||
self._viz_data.qvel[0] = last_state["motor_vel"]
|
||||
self._viz_data.qvel[1] = last_state["pend_vel"]
|
||||
|
||||
mujoco.mj_forward(self._viz_model, self._viz_data)
|
||||
|
||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Offscreen render of the digital-twin MuJoCo model."""
|
||||
import mujoco
|
||||
|
||||
self._sync_viz()
|
||||
|
||||
if self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(
|
||||
self._viz_model, width=640, height=480,
|
||||
)
|
||||
self._offscreen_renderer.update_scene(self._viz_data)
|
||||
return self._offscreen_renderer.render().copy()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override step() for runner-level safety
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
|
||||
# Check for ESP32 reboot / disconnect BEFORE stepping.
|
||||
if self._rebooted or self._serial_disconnected:
|
||||
self._send("M0")
|
||||
qpos, qvel = self._make_current_state()
|
||||
state = self.env.build_state(qpos, qvel)
|
||||
obs = self.env.compute_observations(state)
|
||||
reward = torch.tensor([[-100.0]])
|
||||
terminated = torch.tensor([[True]])
|
||||
truncated = torch.tensor([[False]])
|
||||
return obs, reward, terminated, truncated, {"reboot_detected": True}
|
||||
|
||||
# Normal step via BaseRunner (calls _sim_step → env logic).
|
||||
obs, rewards, terminated, truncated, info = super().step(actions)
|
||||
|
||||
# Check connection health after stepping.
|
||||
if not self._check_connection_health():
|
||||
self._send("M0")
|
||||
terminated = torch.tensor([[True]])
|
||||
rewards = torch.tensor([[-100.0]])
|
||||
info["reboot_detected"] = True
|
||||
|
||||
# Always stop motor on episode end.
|
||||
if terminated.any() or truncated.any():
|
||||
self._send("M0")
|
||||
|
||||
return obs, rewards, terminated, truncated, info
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serial helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send(self, cmd: str) -> None:
|
||||
"""Send a command to the ESP32."""
|
||||
try:
|
||||
self.ser.write(f"{cmd}\n".encode())
|
||||
except (OSError, self._serial_mod.SerialException):
|
||||
self._serial_disconnected = True
|
||||
|
||||
def _serial_reader(self) -> None:
|
||||
"""Background thread: continuously read and parse serial lines.
|
||||
|
||||
Protocol: ``S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>``
|
||||
(7 comma-separated fields). Firmware sends SI units directly.
|
||||
"""
|
||||
while self._reader_running:
|
||||
try:
|
||||
if self.ser.in_waiting:
|
||||
line = (
|
||||
self.ser.readline()
|
||||
.decode("utf-8", errors="ignore")
|
||||
.strip()
|
||||
)
|
||||
|
||||
# Detect ESP32 reboot: it prints READY on startup.
|
||||
if line.startswith("READY"):
|
||||
self._rebooted = True
|
||||
logger.critical("ESP32 reboot detected: %s", line)
|
||||
continue
|
||||
|
||||
if line.startswith("S,"):
|
||||
parts = line.split(",")
|
||||
if len(parts) >= 7:
|
||||
try:
|
||||
esp_ms = int(parts[1])
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Malformed state line (header): %s",
|
||||
line,
|
||||
)
|
||||
continue
|
||||
|
||||
# Detect reboot: timestamp jumped backwards.
|
||||
if (
|
||||
self._last_esp_ms > 5000
|
||||
and esp_ms < self._last_esp_ms - 3000
|
||||
):
|
||||
self._rebooted = True
|
||||
logger.critical(
|
||||
"ESP32 reboot detected: timestamp"
|
||||
" %d -> %d",
|
||||
self._last_esp_ms,
|
||||
esp_ms,
|
||||
)
|
||||
|
||||
self._last_esp_ms = esp_ms
|
||||
self._last_data_time = time.monotonic()
|
||||
|
||||
try:
|
||||
parsed: dict[str, Any] = {
|
||||
"timestamp_ms": esp_ms,
|
||||
"motor_rad": float(parts[2]),
|
||||
"motor_vel": float(parts[3]),
|
||||
"pend_rad": float(parts[4]),
|
||||
"pend_vel": float(parts[5]),
|
||||
"motor_speed": int(parts[6]),
|
||||
}
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Malformed state line (fields): %s",
|
||||
line,
|
||||
)
|
||||
continue
|
||||
with self._state_lock:
|
||||
self._latest_state = parsed
|
||||
self._state_event.set()
|
||||
else:
|
||||
time.sleep(0.001) # Avoid busy-spinning.
|
||||
except (OSError, self._serial_mod.SerialException) as exc:
|
||||
self._serial_disconnected = True
|
||||
logger.critical("Serial connection lost: %s", exc)
|
||||
break
|
||||
|
||||
def _check_connection_health(self) -> bool:
|
||||
"""Return True if the ESP32 connection appears healthy."""
|
||||
if self._serial_disconnected:
|
||||
logger.critical("ESP32 serial connection lost.")
|
||||
return False
|
||||
if (
|
||||
self._streaming
|
||||
and (time.monotonic() - self._last_data_time)
|
||||
> self.config.no_data_timeout
|
||||
):
|
||||
logger.critical(
|
||||
"No data from ESP32 for %.1f s — possible crash/disconnect.",
|
||||
time.monotonic() - self._last_data_time,
|
||||
)
|
||||
self._rebooted = True
|
||||
return False
|
||||
return True
|
||||
|
||||
def _read_state(self) -> dict[str, Any]:
|
||||
"""Return the most recent state from the reader thread (non-blocking)."""
|
||||
with self._state_lock:
|
||||
return dict(self._latest_state)
|
||||
|
||||
def _read_state_blocking(self, timeout: float = 0.05) -> dict[str, Any]:
|
||||
"""Wait for a fresh sample, then return it."""
|
||||
self._state_event.clear()
|
||||
self._state_event.wait(timeout=timeout)
|
||||
with self._state_lock:
|
||||
return dict(self._latest_state)
|
||||
|
||||
def _state_to_tensors(
|
||||
self, state: dict[str, Any],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert a parsed state dict to (qpos, qvel) tensors."""
|
||||
qpos = torch.tensor(
|
||||
[[state["motor_rad"], state["pend_rad"]]], dtype=torch.float32
|
||||
)
|
||||
qvel = torch.tensor(
|
||||
[[state["motor_vel"], state["pend_vel"]]], dtype=torch.float32
|
||||
)
|
||||
return qpos, qvel
|
||||
|
||||
def _make_current_state(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Build qpos/qvel from current sensor data (utility)."""
|
||||
return self._state_to_tensors(self._read_state_blocking())
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Physical reset helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drive_to_center(self) -> None:
|
||||
"""Drive the motor back toward center using bang-bang control."""
|
||||
cfg = self.config
|
||||
start = time.time()
|
||||
while time.time() - start < cfg.reset_drive_timeout:
|
||||
state = self._read_state_blocking()
|
||||
motor_rad = state["motor_rad"]
|
||||
if abs(motor_rad) < cfg.reset_deadband_rad:
|
||||
break
|
||||
speed = cfg.reset_drive_speed if motor_rad < 0 else -cfg.reset_drive_speed
|
||||
self._send(f"M{speed}")
|
||||
time.sleep(0.05)
|
||||
self._send("M0")
|
||||
time.sleep(0.2)
|
||||
|
||||
def _wait_for_settle(self) -> None:
|
||||
"""Block until the env considers the robot ready for a new episode."""
|
||||
cfg = self.config
|
||||
stable_since: float | None = None
|
||||
start = time.monotonic()
|
||||
|
||||
while time.monotonic() - start < cfg.reset_settle_timeout:
|
||||
state = self._read_state_blocking()
|
||||
qpos, qvel = self._state_to_tensors(state)
|
||||
|
||||
if self.env.is_reset_ready(qpos, qvel):
|
||||
if stable_since is None:
|
||||
stable_since = time.monotonic()
|
||||
elif time.monotonic() - stable_since >= 0.5:
|
||||
logger.info(
|
||||
"Robot settled after %.2f s",
|
||||
time.monotonic() - start,
|
||||
)
|
||||
return
|
||||
else:
|
||||
stable_since = None
|
||||
time.sleep(0.02)
|
||||
|
||||
logger.warning(
|
||||
"Robot did not settle within %.1f s — proceeding anyway.",
|
||||
cfg.reset_settle_timeout,
|
||||
)
|
||||
1
src/sysid/__init__.py
Normal file
1
src/sysid/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""System identification — tune simulation parameters to match real hardware."""
|
||||
79
src/sysid/_urdf.py
Normal file
79
src/sysid/_urdf.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""URDF XML helpers shared by sysid rollout and export modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
def set_mass(inertial: ET.Element, mass: float | None) -> None:
|
||||
if mass is None:
|
||||
return
|
||||
mass_el = inertial.find("mass")
|
||||
if mass_el is not None:
|
||||
mass_el.set("value", str(mass))
|
||||
|
||||
|
||||
def set_com(
|
||||
inertial: ET.Element,
|
||||
x: float | None,
|
||||
y: float | None,
|
||||
z: float | None,
|
||||
) -> None:
|
||||
origin = inertial.find("origin")
|
||||
if origin is None:
|
||||
return
|
||||
xyz = origin.get("xyz", "0 0 0").split()
|
||||
if x is not None:
|
||||
xyz[0] = str(x)
|
||||
if y is not None:
|
||||
xyz[1] = str(y)
|
||||
if z is not None:
|
||||
xyz[2] = str(z)
|
||||
origin.set("xyz", " ".join(xyz))
|
||||
|
||||
|
||||
def set_inertia(
|
||||
inertial: ET.Element,
|
||||
ixx: float | None = None,
|
||||
iyy: float | None = None,
|
||||
izz: float | None = None,
|
||||
ixy: float | None = None,
|
||||
iyz: float | None = None,
|
||||
ixz: float | None = None,
|
||||
) -> None:
|
||||
ine = inertial.find("inertia")
|
||||
if ine is None:
|
||||
return
|
||||
for attr, val in [
|
||||
("ixx", ixx), ("iyy", iyy), ("izz", izz),
|
||||
("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
|
||||
]:
|
||||
if val is not None:
|
||||
ine.set(attr, str(val))
|
||||
|
||||
|
||||
def patch_link_inertials(
|
||||
root: ET.Element,
|
||||
params: dict[str, float],
|
||||
) -> None:
|
||||
"""Patch arm and pendulum inertial parameters in a URDF ElementTree root."""
|
||||
for link in root.iter("link"):
|
||||
link_name = link.get("name", "")
|
||||
inertial = link.find("inertial")
|
||||
if inertial is None:
|
||||
continue
|
||||
|
||||
if link_name == "arm":
|
||||
set_mass(inertial, params.get("arm_mass"))
|
||||
set_com(inertial, params.get("arm_com_x"),
|
||||
params.get("arm_com_y"), params.get("arm_com_z"))
|
||||
|
||||
elif link_name == "pendulum":
|
||||
set_mass(inertial, params.get("pendulum_mass"))
|
||||
set_com(inertial, params.get("pendulum_com_x"),
|
||||
params.get("pendulum_com_y"), params.get("pendulum_com_z"))
|
||||
set_inertia(inertial,
|
||||
ixx=params.get("pendulum_ixx"),
|
||||
iyy=params.get("pendulum_iyy"),
|
||||
izz=params.get("pendulum_izz"),
|
||||
ixy=params.get("pendulum_ixy"))
|
||||
434
src/sysid/capture.py
Normal file
434
src/sysid/capture.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""Capture a real-robot trajectory under random excitation (PRBS-style).
|
||||
|
||||
Connects to the ESP32 over serial, sends random PWM commands to excite
|
||||
the system, and records motor + pendulum angles and velocities at ~50 Hz.
|
||||
|
||||
Saves a compressed numpy archive (.npz) that the optimizer can replay
|
||||
in simulation to fit physics parameters.
|
||||
|
||||
Serial protocol (same as SerialRunner):
|
||||
S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
|
||||
(7 comma-separated fields — firmware sends SI units)
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.capture \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--port /dev/cu.usbserial-0001 \
|
||||
--duration 20
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
|
||||
|
||||
|
||||
def _parse_state_line(line: str) -> dict[str, Any] | None:
|
||||
"""Parse an ``S,…`` state line from the ESP32.
|
||||
|
||||
Format: S,<ms>,<motor_rad>,<motor_vel>,<pend_rad>,<pend_vel>,<motor_speed>
|
||||
(7 comma-separated fields, firmware sends SI units)
|
||||
"""
|
||||
if not line.startswith("S,"):
|
||||
return None
|
||||
parts = line.split(",")
|
||||
if len(parts) < 7:
|
||||
return None
|
||||
try:
|
||||
return {
|
||||
"timestamp_ms": int(parts[1]),
|
||||
"motor_rad": float(parts[2]),
|
||||
"motor_vel": float(parts[3]),
|
||||
"pend_rad": float(parts[4]),
|
||||
"pend_vel": float(parts[5]),
|
||||
"motor_speed": int(parts[6]),
|
||||
}
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
# ── Background serial reader ─────────────────────────────────────────
|
||||
|
||||
|
||||
class _SerialReader:
|
||||
"""Minimal background reader for the ESP32 serial stream.
|
||||
|
||||
Uses a sequence counter so ``read_blocking()`` guarantees it returns
|
||||
a *new* state line (not a stale repeat). This keeps the capture
|
||||
loop locked to the firmware's 50 Hz tick.
|
||||
"""
|
||||
|
||||
def __init__(self, port: str, baud: int = 115200):
|
||||
import serial as _serial
|
||||
|
||||
self._serial_mod = _serial
|
||||
self.ser = _serial.Serial(port, baud, timeout=0.05)
|
||||
time.sleep(2) # Wait for ESP32 boot.
|
||||
self.ser.reset_input_buffer()
|
||||
|
||||
self._latest: dict[str, Any] = {}
|
||||
self._seq: int = 0 # incremented on every new state line
|
||||
self._lock = threading.Lock()
|
||||
self._cond = threading.Condition(self._lock)
|
||||
self._running = True
|
||||
|
||||
self._thread = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
_debug_count = 0
|
||||
while self._running:
|
||||
try:
|
||||
if self.ser.in_waiting:
|
||||
line = (
|
||||
self.ser.readline()
|
||||
.decode("utf-8", errors="ignore")
|
||||
.strip()
|
||||
)
|
||||
# Debug: log first 10 raw lines so we can see what the firmware sends.
|
||||
if _debug_count < 10 and line:
|
||||
log.info("serial_raw_line", line=repr(line), count=_debug_count)
|
||||
_debug_count += 1
|
||||
parsed = _parse_state_line(line)
|
||||
if parsed is not None:
|
||||
with self._cond:
|
||||
self._latest = parsed
|
||||
self._seq += 1
|
||||
self._cond.notify_all()
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
except (OSError, self._serial_mod.SerialException):
|
||||
log.critical("serial_lost")
|
||||
break
|
||||
|
||||
def send(self, cmd: str) -> None:
|
||||
try:
|
||||
self.ser.write(f"{cmd}\n".encode())
|
||||
except (OSError, self._serial_mod.SerialException):
|
||||
log.critical("serial_send_failed", cmd=cmd)
|
||||
|
||||
def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
|
||||
"""Wait until a *new* state line arrives, then return it.
|
||||
|
||||
Uses a sequence counter to guarantee the returned state is
|
||||
different from whatever was available before this call.
|
||||
"""
|
||||
with self._cond:
|
||||
seq_before = self._seq
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._seq > seq_before, timeout=timeout
|
||||
):
|
||||
return {} # timeout — no new data
|
||||
return dict(self._latest)
|
||||
|
||||
def close(self) -> None:
|
||||
self._running = False
|
||||
self.send("H")
|
||||
self.send("M0")
|
||||
time.sleep(0.1)
|
||||
self._thread.join(timeout=1.0)
|
||||
self.ser.close()
|
||||
|
||||
|
||||
# ── PRBS excitation signal ───────────────────────────────────────────
|
||||
|
||||
|
||||
class _PRBSExcitation:
|
||||
"""Random hold-value excitation with configurable amplitude and hold time.
|
||||
|
||||
At each call to ``__call__``, returns the current PWM value.
|
||||
The value is held for a random duration (``hold_min``–``hold_max`` ms),
|
||||
then a new random value is drawn uniformly from ``[-amplitude, +amplitude]``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
amplitude: int = 150,
|
||||
hold_min_ms: int = 50,
|
||||
hold_max_ms: int = 300,
|
||||
):
|
||||
self.amplitude = amplitude
|
||||
self.hold_min_ms = hold_min_ms
|
||||
self.hold_max_ms = hold_max_ms
|
||||
self._current: int = 0
|
||||
self._switch_time: float = 0.0
|
||||
self._new_value()
|
||||
|
||||
def _new_value(self) -> None:
|
||||
self._current = random.randint(-self.amplitude, self.amplitude)
|
||||
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
|
||||
self._switch_time = time.monotonic() + hold_ms / 1000.0
|
||||
|
||||
def __call__(self) -> int:
|
||||
if time.monotonic() >= self._switch_time:
|
||||
self._new_value()
|
||||
return self._current
|
||||
|
||||
|
||||
# ── Main capture loop ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def capture(
|
||||
robot_path: str | Path,
|
||||
port: str = "/dev/cu.usbserial-0001",
|
||||
baud: int = 115200,
|
||||
duration: float = 20.0,
|
||||
amplitude: int = 150,
|
||||
hold_min_ms: int = 50,
|
||||
hold_max_ms: int = 300,
|
||||
dt: float = 0.02,
|
||||
motor_angle_limit_deg: float = 90.0,
|
||||
) -> Path:
|
||||
"""Run the capture procedure and return the path to the saved .npz file.
|
||||
|
||||
The capture loop is **stream-driven**: it blocks on each incoming
|
||||
state line from the firmware (which arrives at 50 Hz), sends the
|
||||
next motor command immediately, and records both.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : path to robot asset directory
|
||||
port : serial port for ESP32
|
||||
baud : baud rate
|
||||
duration : capture duration in seconds
|
||||
amplitude : max PWM magnitude for excitation
|
||||
hold_min_ms / hold_max_ms : random hold time range (ms)
|
||||
dt : nominal sample period for buffer sizing (seconds)
|
||||
motor_angle_limit_deg : safety limit for motor angle
|
||||
"""
|
||||
robot_path = Path(robot_path).resolve()
|
||||
|
||||
max_motor_rad = math.radians(motor_angle_limit_deg) if motor_angle_limit_deg > 0 else 0.0
|
||||
|
||||
# Connect.
|
||||
reader = _SerialReader(port, baud)
|
||||
excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)
|
||||
|
||||
# Prepare recording buffers (generous headroom).
|
||||
max_samples = int(duration / dt) + 500
|
||||
rec_time = np.zeros(max_samples, dtype=np.float64)
|
||||
rec_action = np.zeros(max_samples, dtype=np.float64)
|
||||
rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
|
||||
rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
|
||||
rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
|
||||
rec_pend_vel = np.zeros(max_samples, dtype=np.float64)
|
||||
|
||||
# Start streaming.
|
||||
reader.send("G")
|
||||
time.sleep(0.1)
|
||||
|
||||
log.info(
|
||||
"capture_starting",
|
||||
port=port,
|
||||
duration=duration,
|
||||
amplitude=amplitude,
|
||||
hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
|
||||
mode="stream-driven (firmware clock)",
|
||||
)
|
||||
|
||||
idx = 0
|
||||
pwm = 0
|
||||
last_esp_ms = -1 # firmware timestamp of last recorded sample
|
||||
esp_ms_origin: int | None = None # first firmware timestamp
|
||||
no_data_count = 0 # consecutive timeouts with no data
|
||||
t0 = time.monotonic() # host clock for duration check only
|
||||
try:
|
||||
while True:
|
||||
# Block until the firmware sends the next state line (~20 ms).
|
||||
# Timeout at 100 ms prevents hanging if the ESP32 disconnects.
|
||||
state = reader.read_blocking(timeout=0.1)
|
||||
if not state:
|
||||
no_data_count += 1
|
||||
if no_data_count == 30: # 3 seconds with no data
|
||||
log.warning(
|
||||
"no_data_received",
|
||||
msg="No state lines from firmware after 3s. "
|
||||
"Check: is the ESP32 powered? Is it running the right firmware? "
|
||||
"Try pressing the RESET button.",
|
||||
)
|
||||
if no_data_count == 100: # 10 seconds
|
||||
log.critical(
|
||||
"no_data_timeout",
|
||||
msg="No data for 10s — aborting capture.",
|
||||
)
|
||||
break
|
||||
continue # no data yet — retry
|
||||
no_data_count = 0
|
||||
|
||||
# Deduplicate: the firmware may send multiple state lines per
|
||||
# tick (e.g. M-command echo + tick). Only record one sample
|
||||
# per unique firmware timestamp.
|
||||
esp_ms = state.get("timestamp_ms", 0)
|
||||
if esp_ms == last_esp_ms:
|
||||
continue
|
||||
last_esp_ms = esp_ms
|
||||
|
||||
# Use firmware clock for time axis (avoids host serial jitter).
|
||||
if esp_ms_origin is None:
|
||||
esp_ms_origin = esp_ms
|
||||
elapsed = (esp_ms - esp_ms_origin) / 1000.0
|
||||
if elapsed >= duration:
|
||||
break
|
||||
|
||||
# Get excitation PWM for the NEXT tick.
|
||||
pwm = excitation()
|
||||
|
||||
# Safety: keep the arm well within its mechanical range.
|
||||
# Firmware sends motor angle in radians — use directly.
|
||||
motor_angle_rad = state.get("motor_rad", 0.0)
|
||||
if max_motor_rad > 0:
|
||||
ratio = motor_angle_rad / max_motor_rad # signed, -1..+1
|
||||
abs_ratio = abs(ratio)
|
||||
|
||||
if abs_ratio > 0.90:
|
||||
# Deep in the danger zone — force a strong return.
|
||||
brake_strength = min(1.0, (abs_ratio - 0.90) / 0.10) # 0→1
|
||||
brake_pwm = int(amplitude * (0.5 + 0.5 * brake_strength))
|
||||
pwm = -brake_pwm if ratio > 0 else brake_pwm
|
||||
elif abs_ratio > 0.70:
|
||||
# Soft zone — only allow actions pointing back to centre.
|
||||
if ratio > 0 and pwm > 0:
|
||||
pwm = -abs(pwm)
|
||||
elif ratio < 0 and pwm < 0:
|
||||
pwm = abs(pwm)
|
||||
|
||||
# Send command immediately — it will take effect on the next tick.
|
||||
reader.send(f"M{pwm}")
|
||||
|
||||
# Record this tick's state + the action the motor *actually*
|
||||
# received. Firmware sends SI units — use directly.
|
||||
motor_angle = state.get("motor_rad", 0.0)
|
||||
motor_vel = state.get("motor_vel", 0.0)
|
||||
pend_angle = state.get("pend_rad", 0.0)
|
||||
pend_vel = state.get("pend_vel", 0.0)
|
||||
# Firmware constrains to ±255; normalise to [-1, 1].
|
||||
applied = state.get("motor_speed", 0)
|
||||
action_norm = max(-255, min(255, applied)) / 255.0
|
||||
|
||||
if idx < max_samples:
|
||||
rec_time[idx] = elapsed
|
||||
rec_action[idx] = action_norm
|
||||
rec_motor_angle[idx] = motor_angle
|
||||
rec_motor_vel[idx] = motor_vel
|
||||
rec_pend_angle[idx] = pend_angle
|
||||
rec_pend_vel[idx] = pend_vel
|
||||
idx += 1
|
||||
else:
|
||||
break # buffer full
|
||||
|
||||
# Progress (every 50 samples ≈ once per second at 50 Hz).
|
||||
if idx % 50 == 0:
|
||||
log.info(
|
||||
"capture_progress",
|
||||
elapsed=f"{elapsed:.1f}/{duration:.0f}s",
|
||||
samples=idx,
|
||||
pwm=pwm,
|
||||
)
|
||||
|
||||
finally:
|
||||
reader.send("M0")
|
||||
reader.close()
|
||||
|
||||
# Trim to actual sample count.
|
||||
rec_time = rec_time[:idx]
|
||||
rec_action = rec_action[:idx]
|
||||
rec_motor_angle = rec_motor_angle[:idx]
|
||||
rec_motor_vel = rec_motor_vel[:idx]
|
||||
rec_pend_angle = rec_pend_angle[:idx]
|
||||
rec_pend_vel = rec_pend_vel[:idx]
|
||||
|
||||
# Save.
|
||||
recordings_dir = robot_path / "recordings"
|
||||
recordings_dir.mkdir(exist_ok=True)
|
||||
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = recordings_dir / f"capture_{stamp}.npz"
|
||||
np.savez_compressed(
|
||||
out_path,
|
||||
time=rec_time,
|
||||
action=rec_action,
|
||||
motor_angle=rec_motor_angle,
|
||||
motor_vel=rec_motor_vel,
|
||||
pendulum_angle=rec_pend_angle,
|
||||
pendulum_vel=rec_pend_vel,
|
||||
)
|
||||
|
||||
log.info(
|
||||
"capture_saved",
|
||||
path=str(out_path),
|
||||
samples=idx,
|
||||
duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
|
||||
)
|
||||
return out_path
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Capture a real-robot trajectory for system identification."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="assets/rotary_cartpole",
|
||||
help="Path to robot asset directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=str,
|
||||
default="/dev/cu.usbserial-0001",
|
||||
help="Serial port for ESP32",
|
||||
)
|
||||
parser.add_argument("--baud", type=int, default=115200)
|
||||
parser.add_argument(
|
||||
"--duration", type=float, default=20.0, help="Capture duration (s)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--amplitude", type=int, default=150,
|
||||
help="Max PWM magnitude for excitation (0-255)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hold-max-ms", type=int, default=300, help="Max hold time (ms)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dt", type=float, default=0.02, help="Nominal sample period for buffer sizing (s)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--motor-angle-limit", type=float, default=90.0,
|
||||
help="Motor angle safety limit in degrees (0 = disabled)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
capture(
|
||||
robot_path=args.robot_path,
|
||||
port=args.port,
|
||||
baud=args.baud,
|
||||
duration=args.duration,
|
||||
amplitude=args.amplitude,
|
||||
hold_min_ms=args.hold_min_ms,
|
||||
hold_max_ms=args.hold_max_ms,
|
||||
dt=args.dt,
|
||||
motor_angle_limit_deg=args.motor_angle_limit,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
182
src/sysid/export.py
Normal file
182
src/sysid/export.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Export tuned parameters to URDF and robot.yaml files.
|
||||
|
||||
Reads the original files, injects the optimised parameter values,
|
||||
and writes ``rotary_cartpole_tuned.urdf`` + ``robot_tuned.yaml``
|
||||
alongside the originals in the robot asset directory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
from src.sysid._urdf import patch_link_inertials
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def export_tuned_files(
|
||||
robot_path: str | Path,
|
||||
params: dict[str, float],
|
||||
motor_params: dict[str, float] | None = None,
|
||||
) -> tuple[Path, Path]:
|
||||
"""Write tuned URDF and robot.yaml files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : robot asset directory (contains robot.yaml + *.urdf)
|
||||
params : dict of parameter name → tuned value (the optimised set)
|
||||
motor_params : locked motor parameters merged underneath ``params``
|
||||
(``params`` wins on conflicts) so the exported YAML always has a
|
||||
complete motor model
|
||||
|
||||
Returns
|
||||
-------
|
||||
(tuned_urdf_path, tuned_robot_yaml_path)
|
||||
"""
|
||||
robot_path = Path(robot_path).resolve()
|
||||
if motor_params:
|
||||
params = {**motor_params, **params}
|
||||
|
||||
# ── Load originals ───────────────────────────────────────────
|
||||
robot_yaml_path = robot_path / "robot.yaml"
|
||||
robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
|
||||
urdf_path = robot_path / robot_cfg["urdf"]
|
||||
|
||||
# ── Tune URDF ────────────────────────────────────────────────
|
||||
tree = ET.parse(urdf_path)
|
||||
patch_link_inertials(tree.getroot(), params)
|
||||
|
||||
# Write tuned URDF.
|
||||
tuned_urdf_name = urdf_path.stem + "_tuned" + urdf_path.suffix
|
||||
tuned_urdf_path = robot_path / tuned_urdf_name
|
||||
|
||||
# Preserve the XML declaration and original formatting as much as possible.
|
||||
ET.indent(tree, space=" ")
|
||||
tree.write(str(tuned_urdf_path), xml_declaration=True, encoding="unicode")
|
||||
log.info("tuned_urdf_written", path=str(tuned_urdf_path))
|
||||
|
||||
# ── Tune robot.yaml ──────────────────────────────────────────
|
||||
tuned_cfg = copy.deepcopy(robot_cfg)
|
||||
|
||||
# Point to the tuned URDF.
|
||||
tuned_cfg["urdf"] = tuned_urdf_name
|
||||
|
||||
# Update actuator parameters — full asymmetric motor model.
|
||||
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
|
||||
act = tuned_cfg["actuators"][0]
|
||||
|
||||
# Asymmetric gear, damping, deadzone, frictionloss as [pos, neg].
|
||||
act["gear"] = [
|
||||
round(params.get("actuator_gear_pos", 0.424), 6),
|
||||
round(params.get("actuator_gear_neg", 0.425), 6),
|
||||
]
|
||||
act["damping"] = [
|
||||
round(params.get("motor_damping_pos", 0.002), 6),
|
||||
round(params.get("motor_damping_neg", 0.015), 6),
|
||||
]
|
||||
act["deadzone"] = [
|
||||
round(params.get("motor_deadzone_pos", 0.141), 6),
|
||||
round(params.get("motor_deadzone_neg", 0.078), 6),
|
||||
]
|
||||
act["frictionloss"] = [
|
||||
round(params.get("motor_frictionloss_pos", 0.057), 6),
|
||||
round(params.get("motor_frictionloss_neg", 0.053), 6),
|
||||
]
|
||||
if "actuator_filter_tau" in params:
|
||||
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
|
||||
|
||||
# Stribeck friction and action bias.
|
||||
if "stribeck_friction_boost" in params:
|
||||
act["stribeck_friction_boost"] = round(params["stribeck_friction_boost"], 6)
|
||||
if "stribeck_vel" in params:
|
||||
act["stribeck_vel"] = round(params["stribeck_vel"], 6)
|
||||
if "action_bias" in params:
|
||||
act["action_bias"] = round(params["action_bias"], 6)
|
||||
|
||||
# ctrl_range from ctrl_limit parameter.
|
||||
if "ctrl_limit" in params:
|
||||
lim = round(params["ctrl_limit"], 6)
|
||||
act["ctrl_range"] = [-lim, lim]
|
||||
|
||||
# Update joint overrides.
|
||||
if "joints" not in tuned_cfg:
|
||||
tuned_cfg["joints"] = {}
|
||||
|
||||
if "motor_joint" not in tuned_cfg["joints"]:
|
||||
tuned_cfg["joints"]["motor_joint"] = {}
|
||||
mj = tuned_cfg["joints"]["motor_joint"]
|
||||
if "motor_armature" in params:
|
||||
mj["armature"] = round(params["motor_armature"], 6)
|
||||
# Frictionloss/damping = 0 in MuJoCo (motor model handles via qfrc_applied).
|
||||
mj["frictionloss"] = 0.0
|
||||
|
||||
if "pendulum_joint" not in tuned_cfg["joints"]:
|
||||
tuned_cfg["joints"]["pendulum_joint"] = {}
|
||||
pj = tuned_cfg["joints"]["pendulum_joint"]
|
||||
if "pendulum_damping" in params:
|
||||
pj["damping"] = round(params["pendulum_damping"], 6)
|
||||
if "pendulum_frictionloss" in params:
|
||||
pj["frictionloss"] = round(params["pendulum_frictionloss"], 6)
|
||||
|
||||
# Write tuned robot.yaml.
|
||||
tuned_yaml_path = robot_path / "robot_tuned.yaml"
|
||||
|
||||
# Add a header comment.
|
||||
header = (
|
||||
"# Tuned robot config — generated by src.sysid.optimize\n"
|
||||
"# Original: robot.yaml\n"
|
||||
"# Run `python -m src.sysid.visualize` to compare real vs sim.\n\n"
|
||||
)
|
||||
tuned_yaml_path.write_text(
|
||||
header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
|
||||
)
|
||||
log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))
|
||||
|
||||
return tuned_urdf_path, tuned_yaml_path
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import argparse
|
||||
import json
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export tuned URDF + robot.yaml from sysid results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-path", type=str, default="assets/rotary_cartpole",
|
||||
help="Path to robot asset directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result", type=str, default=None,
|
||||
help="Path to sysid_result.json (auto-detected if omitted)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_path = Path(args.robot_path).resolve()
|
||||
if args.result:
|
||||
result_path = Path(args.result)
|
||||
else:
|
||||
result_path = robot_path / "sysid_result.json"
|
||||
|
||||
if not result_path.exists():
|
||||
raise FileNotFoundError(f"Result file not found: {result_path}")
|
||||
|
||||
result = json.loads(result_path.read_text())
|
||||
|
||||
export_tuned_files(
|
||||
robot_path=args.robot_path,
|
||||
params=result["best_params"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
186
src/sysid/motor/export.py
Normal file
186
src/sysid/motor/export.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Export tuned motor parameters to MJCF and robot.yaml files.
|
||||
|
||||
Reads the original motor.xml and robot.yaml, patches with optimised
|
||||
parameter values, and writes motor_tuned.xml + robot_tuned.yaml.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.motor.export \
|
||||
--asset-path assets/motor \
|
||||
--result assets/motor/motor_sysid_result.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
_DEFAULT_ASSET = "assets/motor"
|
||||
|
||||
|
||||
def export_tuned_files(
|
||||
asset_path: str | Path,
|
||||
params: dict[str, float],
|
||||
) -> tuple[Path, Path]:
|
||||
"""Write tuned MJCF and robot.yaml files.
|
||||
|
||||
Returns (tuned_mjcf_path, tuned_robot_yaml_path).
|
||||
"""
|
||||
asset_path = Path(asset_path).resolve()
|
||||
|
||||
robot_yaml_path = asset_path / "robot.yaml"
|
||||
robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
|
||||
mjcf_path = asset_path / robot_cfg["mjcf"]
|
||||
|
||||
# ── Tune MJCF ────────────────────────────────────────────────
|
||||
tree = ET.parse(str(mjcf_path))
|
||||
root = tree.getroot()
|
||||
|
||||
# Actuator — use average gear for the MJCF model.
|
||||
gear_pos = params.get("actuator_gear_pos", params.get("actuator_gear"))
|
||||
gear_neg = params.get("actuator_gear_neg", params.get("actuator_gear"))
|
||||
gear_avg = None
|
||||
if gear_pos is not None and gear_neg is not None:
|
||||
gear_avg = (gear_pos + gear_neg) / 2.0
|
||||
elif gear_pos is not None:
|
||||
gear_avg = gear_pos
|
||||
filter_tau = params.get("actuator_filter_tau")
|
||||
for act_el in root.iter("general"):
|
||||
if act_el.get("name") == "motor":
|
||||
if gear_avg is not None:
|
||||
act_el.set("gear", str(gear_avg))
|
||||
if filter_tau is not None:
|
||||
if filter_tau > 0:
|
||||
act_el.set("dyntype", "filter")
|
||||
act_el.set("dynprm", str(filter_tau))
|
||||
else:
|
||||
act_el.set("dyntype", "none")
|
||||
|
||||
# Joint — average damping & friction for MJCF (asymmetry in runtime).
|
||||
fl_pos = params.get("motor_frictionloss_pos", params.get("motor_frictionloss"))
|
||||
fl_neg = params.get("motor_frictionloss_neg", params.get("motor_frictionloss"))
|
||||
fl_avg = None
|
||||
if fl_pos is not None and fl_neg is not None:
|
||||
fl_avg = (fl_pos + fl_neg) / 2.0
|
||||
elif fl_pos is not None:
|
||||
fl_avg = fl_pos
|
||||
damp_pos = params.get("motor_damping_pos", params.get("motor_damping"))
|
||||
damp_neg = params.get("motor_damping_neg", params.get("motor_damping"))
|
||||
damp_avg = None
|
||||
if damp_pos is not None and damp_neg is not None:
|
||||
damp_avg = (damp_pos + damp_neg) / 2.0
|
||||
elif damp_pos is not None:
|
||||
damp_avg = damp_pos
|
||||
for jnt in root.iter("joint"):
|
||||
if jnt.get("name") == "motor_joint":
|
||||
if damp_avg is not None:
|
||||
jnt.set("damping", str(damp_avg))
|
||||
if "motor_armature" in params:
|
||||
jnt.set("armature", str(params["motor_armature"]))
|
||||
if fl_avg is not None:
|
||||
jnt.set("frictionloss", str(fl_avg))
|
||||
|
||||
# Rotor mass.
|
||||
if "rotor_mass" in params:
|
||||
for geom in root.iter("geom"):
|
||||
if geom.get("name") == "rotor_disk":
|
||||
geom.set("mass", str(params["rotor_mass"]))
|
||||
|
||||
# Write tuned MJCF.
|
||||
tuned_mjcf_name = mjcf_path.stem + "_tuned" + mjcf_path.suffix
|
||||
tuned_mjcf_path = asset_path / tuned_mjcf_name
|
||||
ET.indent(tree, space=" ")
|
||||
tree.write(str(tuned_mjcf_path), xml_declaration=True, encoding="unicode")
|
||||
log.info("tuned_mjcf_written", path=str(tuned_mjcf_path))
|
||||
|
||||
# ── Tune robot.yaml ──────────────────────────────────────────
|
||||
tuned_cfg = copy.deepcopy(robot_cfg)
|
||||
tuned_cfg["mjcf"] = tuned_mjcf_name
|
||||
|
||||
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
|
||||
act = tuned_cfg["actuators"][0]
|
||||
if gear_avg is not None:
|
||||
act["gear"] = round(gear_avg, 6)
|
||||
if "actuator_filter_tau" in params:
|
||||
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
|
||||
if "motor_damping" in params:
|
||||
act["damping"] = round(params["motor_damping"], 6)
|
||||
|
||||
if "joints" not in tuned_cfg:
|
||||
tuned_cfg["joints"] = {}
|
||||
if "motor_joint" not in tuned_cfg["joints"]:
|
||||
tuned_cfg["joints"]["motor_joint"] = {}
|
||||
mj = tuned_cfg["joints"]["motor_joint"]
|
||||
if "motor_armature" in params:
|
||||
mj["armature"] = round(params["motor_armature"], 6)
|
||||
if fl_avg is not None:
|
||||
mj["frictionloss"] = round(fl_avg, 6)
|
||||
|
||||
# Asymmetric / hardware-realism / nonlinear parameters.
|
||||
realism = {}
|
||||
for key in [
|
||||
"actuator_gear_pos", "actuator_gear_neg",
|
||||
"motor_damping_pos", "motor_damping_neg",
|
||||
"motor_frictionloss_pos", "motor_frictionloss_neg",
|
||||
"motor_deadzone_pos", "motor_deadzone_neg",
|
||||
"action_bias",
|
||||
"viscous_quadratic", "back_emf_gain",
|
||||
"stribeck_friction_boost", "stribeck_vel",
|
||||
"gearbox_backlash",
|
||||
]:
|
||||
if key in params:
|
||||
realism[key] = round(params[key], 6)
|
||||
if realism:
|
||||
tuned_cfg["hardware_realism"] = realism
|
||||
|
||||
tuned_yaml_path = asset_path / "robot_tuned.yaml"
|
||||
header = (
|
||||
"# Tuned motor config — generated by src.sysid.motor.optimize\n"
|
||||
"# Original: robot.yaml\n\n"
|
||||
)
|
||||
tuned_yaml_path.write_text(
|
||||
header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
|
||||
)
|
||||
log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))
|
||||
|
||||
return tuned_mjcf_path, tuned_yaml_path
|
||||
|
||||
|
||||
# ── CLI ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Export tuned motor parameters to MJCF + robot.yaml."
|
||||
)
|
||||
parser.add_argument("--asset-path", type=str, default=_DEFAULT_ASSET)
|
||||
parser.add_argument(
|
||||
"--result", type=str, default=None,
|
||||
help="Path to motor_sysid_result.json (auto-detected if omitted)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asset_path = Path(args.asset_path).resolve()
|
||||
if args.result:
|
||||
result_path = Path(args.result)
|
||||
else:
|
||||
result_path = asset_path / "motor_sysid_result.json"
|
||||
|
||||
if not result_path.exists():
|
||||
raise FileNotFoundError(f"Result file not found: {result_path}")
|
||||
|
||||
result = json.loads(result_path.read_text())
|
||||
params = result["best_params"]
|
||||
|
||||
export_tuned_files(asset_path=args.asset_path, params=params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
580
src/sysid/optimize.py
Normal file
580
src/sysid/optimize.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""CMA-ES optimiser — fit simulation parameters to a real-robot recording.
|
||||
|
||||
Minimises the trajectory-matching cost between a MuJoCo rollout and a
|
||||
recorded real-robot sequence. Uses the ``cmaes`` package (pure-Python
|
||||
CMA-ES with native box-constraint support).
|
||||
|
||||
Motor parameters are **locked** from the motor-only sysid — only
|
||||
pendulum/arm inertial parameters, joint dynamics, and ctrl_limit are
|
||||
optimised. Velocities are optionally preprocessed with Savitzky-Golay
|
||||
differentiation for cleaner targets.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.optimize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording assets/rotary_cartpole/recordings/capture_20260314_000435.npz
|
||||
|
||||
# Shorter run for testing:
|
||||
python -m src.sysid.optimize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording <file>.npz \
|
||||
--max-generations 10 --population-size 8
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
from src.sysid.rollout import (
|
||||
LOCKED_MOTOR_PARAMS,
|
||||
PARAM_SETS,
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
ParamSpec,
|
||||
bounds_arrays,
|
||||
defaults_vector,
|
||||
params_to_dict,
|
||||
rollout,
|
||||
windowed_rollout,
|
||||
)
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
# ── Velocity preprocessing ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _preprocess_recording(
|
||||
recording: dict[str, np.ndarray],
|
||||
preprocess_vel: bool = True,
|
||||
sg_window: int = 7,
|
||||
sg_polyorder: int = 3,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Optionally recompute velocities using Savitzky-Golay differentiation.
|
||||
|
||||
Applies SG filtering to both motor_vel and pendulum_vel, replacing
|
||||
the noisy firmware finite-difference velocities with smooth
|
||||
analytical derivatives of the (clean) angle signals.
|
||||
"""
|
||||
if not preprocess_vel:
|
||||
return recording
|
||||
|
||||
from scipy.signal import savgol_filter
|
||||
|
||||
rec = dict(recording)
|
||||
times = rec["time"]
|
||||
dt = float(np.mean(np.diff(times)))
|
||||
|
||||
# Motor velocity.
|
||||
rec["motor_vel_raw"] = rec["motor_vel"].copy()
|
||||
rec["motor_vel"] = savgol_filter(
|
||||
rec["motor_angle"],
|
||||
window_length=sg_window,
|
||||
polyorder=sg_polyorder,
|
||||
deriv=1,
|
||||
delta=dt,
|
||||
)
|
||||
|
||||
# Pendulum velocity.
|
||||
rec["pendulum_vel_raw"] = rec["pendulum_vel"].copy()
|
||||
rec["pendulum_vel"] = savgol_filter(
|
||||
rec["pendulum_angle"],
|
||||
window_length=sg_window,
|
||||
polyorder=sg_polyorder,
|
||||
deriv=1,
|
||||
delta=dt,
|
||||
)
|
||||
|
||||
motor_noise = np.std(rec["motor_vel_raw"] - rec["motor_vel"])
|
||||
pend_noise = np.std(rec["pendulum_vel_raw"] - rec["pendulum_vel"])
|
||||
log.info(
|
||||
"velocity_preprocessed",
|
||||
method="savgol",
|
||||
sg_window=sg_window,
|
||||
sg_polyorder=sg_polyorder,
|
||||
dt_ms=f"{dt*1000:.1f}",
|
||||
motor_noise_std=f"{motor_noise:.3f} rad/s",
|
||||
pend_noise_std=f"{pend_noise:.3f} rad/s",
|
||||
)
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
# ── Cost function ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _angle_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
"""Shortest signed angle difference, handling wrapping."""
|
||||
return np.arctan2(np.sin(a - b), np.cos(a - b))
|
||||
|
||||
|
||||
def _check_inertia_valid(params: dict[str, float]) -> bool:
|
||||
"""Quick reject: pendulum inertia tensor must be positive-definite."""
|
||||
ixx = params.get("pendulum_ixx", 6.16e-06)
|
||||
iyy = params.get("pendulum_iyy", 6.16e-06)
|
||||
izz = params.get("pendulum_izz", 1.23e-05)
|
||||
ixy = params.get("pendulum_ixy", 6.10e-06)
|
||||
det_xy = ixx * iyy - ixy * ixy
|
||||
return det_xy > 0 and ixx > 0 and iyy > 0 and izz > 0
|
||||
|
||||
|
||||
def _compute_trajectory_cost(
|
||||
sim: dict[str, np.ndarray],
|
||||
recording: dict[str, np.ndarray],
|
||||
pos_weight: float = 1.0,
|
||||
vel_weight: float = 0.1,
|
||||
pendulum_scale: float = 3.0,
|
||||
vel_outlier_threshold: float = 20.0,
|
||||
) -> float:
|
||||
"""Weighted MSE between sim and real trajectories.
|
||||
|
||||
pendulum_scale multiplies the pendulum terms relative to motor terms.
|
||||
|
||||
Samples where the *real* pendulum velocity exceeds
|
||||
``vel_outlier_threshold`` (rad/s) are excluded from the velocity
|
||||
terms. These are encoder-wrap artefacts (pendulum swinging past
|
||||
vertical) that no simulator can reproduce.
|
||||
"""
|
||||
motor_err = _angle_diff(sim["motor_angle"], recording["motor_angle"])
|
||||
pend_err = _angle_diff(sim["pendulum_angle"], recording["pendulum_angle"])
|
||||
motor_vel_err = sim["motor_vel"] - recording["motor_vel"]
|
||||
pend_vel_err = sim["pendulum_vel"] - recording["pendulum_vel"]
|
||||
|
||||
# Mask out encoder-wrap velocity spikes so the optimizer doesn't
|
||||
# waste capacity fitting artefacts.
|
||||
valid = np.abs(recording["pendulum_vel"]) < vel_outlier_threshold
|
||||
if valid.sum() < len(valid):
|
||||
motor_vel_err = motor_vel_err[valid]
|
||||
pend_vel_err = pend_vel_err[valid]
|
||||
|
||||
return float(
|
||||
pos_weight * np.mean(motor_err**2)
|
||||
+ pos_weight * pendulum_scale * np.mean(pend_err**2)
|
||||
+ vel_weight * np.mean(motor_vel_err**2)
|
||||
+ vel_weight * pendulum_scale * np.mean(pend_vel_err**2)
|
||||
)
|
||||
|
||||
|
||||
def cost_function(
|
||||
params_vec: np.ndarray,
|
||||
recording: dict[str, np.ndarray],
|
||||
robot_path: Path,
|
||||
specs: list[ParamSpec],
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
pos_weight: float = 1.0,
|
||||
vel_weight: float = 0.1,
|
||||
pendulum_scale: float = 3.0,
|
||||
window_duration: float = 0.5,
|
||||
motor_params: dict[str, float] | None = None,
|
||||
) -> float:
|
||||
"""Compute trajectory-matching cost for a candidate parameter vector.
|
||||
|
||||
Uses **multiple-shooting** (windowed rollout): the recording is split
|
||||
into short windows (default 0.5 s). Each window is initialised from
|
||||
the real qpos/qvel, so early errors don’t compound across the full
|
||||
trajectory. This gives a much smoother cost landscape for CMA-ES.
|
||||
|
||||
Set ``window_duration=0`` to fall back to the original open-loop
|
||||
single-shot rollout (not recommended).
|
||||
"""
|
||||
params = params_to_dict(params_vec, specs)
|
||||
|
||||
if not _check_inertia_valid(params):
|
||||
return 1e6
|
||||
|
||||
try:
|
||||
if window_duration > 0:
|
||||
sim = windowed_rollout(
|
||||
robot_path=robot_path,
|
||||
params=params,
|
||||
recording=recording,
|
||||
window_duration=window_duration,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
motor_params=motor_params,
|
||||
)
|
||||
else:
|
||||
sim = rollout(
|
||||
robot_path=robot_path,
|
||||
params=params,
|
||||
actions=recording["action"],
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
motor_params=motor_params,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("rollout_failed", error=str(exc))
|
||||
return 1e6
|
||||
|
||||
# Check for NaN in sim output.
|
||||
for key in ("motor_angle", "motor_vel", "pendulum_angle", "pendulum_vel"):
|
||||
if np.any(~np.isfinite(sim[key])):
|
||||
return 1e6
|
||||
|
||||
return _compute_trajectory_cost(sim, recording, pos_weight, vel_weight, pendulum_scale)
|
||||
|
||||
|
||||
# ── Parallel evaluation helper (module-level for pickling) ───────────
|
||||
|
||||
# Shared state set by the parent process before spawning workers.
|
||||
_par_recording: dict[str, np.ndarray] = {}
|
||||
_par_robot_path: Path = Path(".")
|
||||
_par_specs: list[ParamSpec] = []
|
||||
_par_kwargs: dict = {}
|
||||
|
||||
|
||||
def _init_worker(recording, robot_path, specs, kwargs):
|
||||
"""Initialiser run once per worker process."""
|
||||
global _par_recording, _par_robot_path, _par_specs, _par_kwargs
|
||||
_par_recording = recording
|
||||
_par_robot_path = robot_path
|
||||
_par_specs = specs
|
||||
_par_kwargs = kwargs
|
||||
|
||||
|
||||
def _eval_candidate(x_natural: np.ndarray) -> float:
|
||||
"""Evaluate a single candidate — called in worker processes."""
|
||||
return cost_function(
|
||||
x_natural,
|
||||
_par_recording,
|
||||
_par_robot_path,
|
||||
_par_specs,
|
||||
**_par_kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ── CMA-ES optimisation loop ────────────────────────────────────────
|
||||
|
||||
|
||||
def optimize(
|
||||
robot_path: str | Path,
|
||||
recording_path: str | Path,
|
||||
specs: list[ParamSpec] | None = None,
|
||||
sigma0: float = 0.3,
|
||||
population_size: int = 20,
|
||||
max_generations: int = 1000,
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
pos_weight: float = 1.0,
|
||||
vel_weight: float = 0.1,
|
||||
pendulum_scale: float = 3.0,
|
||||
window_duration: float = 0.5,
|
||||
seed: int = 42,
|
||||
preprocess_vel: bool = True,
|
||||
sg_window: int = 7,
|
||||
sg_polyorder: int = 3,
|
||||
) -> dict:
|
||||
"""Run CMA-ES optimisation and return results.
|
||||
|
||||
Motor parameters are locked from the motor-only sysid.
|
||||
Only pendulum/arm parameters are optimised.
|
||||
|
||||
Returns a dict with:
|
||||
best_params: dict[str, float]
|
||||
best_cost: float
|
||||
history: list of (generation, best_cost) tuples
|
||||
recording: str (path used)
|
||||
specs: list of param names
|
||||
motor_params: dict of locked motor params
|
||||
"""
|
||||
from cmaes import CMA
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
recording_path = Path(recording_path).resolve()
|
||||
|
||||
if specs is None:
|
||||
specs = ROTARY_CARTPOLE_PARAMS
|
||||
|
||||
motor_params = dict(LOCKED_MOTOR_PARAMS)
|
||||
log.info(
|
||||
"motor_params_locked",
|
||||
n_params=len(motor_params),
|
||||
gear_avg=f"{(motor_params['actuator_gear_pos'] + motor_params['actuator_gear_neg']) / 2:.4f}",
|
||||
)
|
||||
|
||||
# Load recording.
|
||||
recording = dict(np.load(recording_path))
|
||||
|
||||
# Preprocessing: SG velocity recomputation.
|
||||
recording = _preprocess_recording(
|
||||
recording,
|
||||
preprocess_vel=preprocess_vel,
|
||||
sg_window=sg_window,
|
||||
sg_polyorder=sg_polyorder,
|
||||
)
|
||||
|
||||
n_samples = len(recording["time"])
|
||||
duration = recording["time"][-1] - recording["time"][0]
|
||||
n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
|
||||
log.info(
|
||||
"recording_loaded",
|
||||
path=str(recording_path),
|
||||
samples=n_samples,
|
||||
duration=f"{duration:.1f}s",
|
||||
window_duration=f"{window_duration}s",
|
||||
n_windows=n_windows,
|
||||
)
|
||||
|
||||
# Initial point (defaults) — normalised to [0, 1] for CMA-ES.
|
||||
lo, hi = bounds_arrays(specs)
|
||||
x0 = defaults_vector(specs)
|
||||
|
||||
# Normalise to [0, 1] for the optimizer (better conditioned).
|
||||
span = hi - lo
|
||||
span[span == 0] = 1.0 # avoid division by zero
|
||||
|
||||
def to_normed(x: np.ndarray) -> np.ndarray:
|
||||
return (x - lo) / span
|
||||
|
||||
def from_normed(x_n: np.ndarray) -> np.ndarray:
|
||||
return x_n * span + lo
|
||||
|
||||
x0_normed = to_normed(x0)
|
||||
bounds_normed = np.column_stack(
|
||||
[np.zeros(len(specs)), np.ones(len(specs))]
|
||||
)
|
||||
|
||||
optimizer = CMA(
|
||||
mean=x0_normed,
|
||||
sigma=sigma0,
|
||||
bounds=bounds_normed,
|
||||
population_size=population_size,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
best_cost = float("inf")
|
||||
best_params_vec = x0.copy()
|
||||
history: list[tuple[int, float]] = []
|
||||
|
||||
log.info(
|
||||
"cmaes_starting",
|
||||
n_params=len(specs),
|
||||
population=population_size,
|
||||
max_gens=max_generations,
|
||||
sigma0=sigma0,
|
||||
)
|
||||
|
||||
t0 = time.monotonic()
|
||||
|
||||
# ── Parallel evaluation setup ────────────────────────────────
|
||||
# Each candidate is independent — evaluate them in parallel using
|
||||
# a process pool. Falls back to sequential if n_workers=1.
|
||||
import multiprocessing as mp
|
||||
n_workers = max(1, mp.cpu_count() - 1) # leave 1 core free
|
||||
|
||||
eval_kwargs = dict(
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
pos_weight=pos_weight,
|
||||
vel_weight=vel_weight,
|
||||
pendulum_scale=pendulum_scale,
|
||||
window_duration=window_duration,
|
||||
motor_params=motor_params,
|
||||
)
|
||||
|
||||
log.info("parallel_workers", n_workers=n_workers)
|
||||
|
||||
# Create a persistent pool (avoids per-generation fork overhead).
|
||||
pool = None
|
||||
if n_workers > 1:
|
||||
pool = mp.Pool(
|
||||
n_workers,
|
||||
initializer=_init_worker,
|
||||
initargs=(recording, robot_path, specs, eval_kwargs),
|
||||
)
|
||||
|
||||
for gen in range(max_generations):
|
||||
# Ask all candidates first.
|
||||
candidates_normed = []
|
||||
candidates_natural = []
|
||||
for _ in range(optimizer.population_size):
|
||||
x_normed = optimizer.ask()
|
||||
x_natural = from_normed(x_normed)
|
||||
x_natural = np.clip(x_natural, lo, hi)
|
||||
candidates_normed.append(x_normed)
|
||||
candidates_natural.append(x_natural)
|
||||
|
||||
# Evaluate in parallel.
|
||||
if pool is not None:
|
||||
costs = pool.map(_eval_candidate, candidates_natural)
|
||||
else:
|
||||
costs = [cost_function(
|
||||
x, recording, robot_path, specs, **eval_kwargs
|
||||
) for x in candidates_natural]
|
||||
|
||||
solutions = list(zip(candidates_normed, costs))
|
||||
for x_natural, c in zip(candidates_natural, costs):
|
||||
if c < best_cost:
|
||||
best_cost = c
|
||||
best_params_vec = x_natural.copy()
|
||||
|
||||
optimizer.tell(solutions)
|
||||
history.append((gen, best_cost))
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
if gen % 5 == 0 or gen == max_generations - 1:
|
||||
log.info(
|
||||
"cmaes_generation",
|
||||
gen=gen,
|
||||
best_cost=f"{best_cost:.6f}",
|
||||
elapsed=f"{elapsed:.1f}s",
|
||||
gen_best=f"{min(c for _, c in solutions):.6f}",
|
||||
)
|
||||
|
||||
total_time = time.monotonic() - t0
|
||||
|
||||
# Clean up process pool.
|
||||
if pool is not None:
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
best_params = params_to_dict(best_params_vec, specs)
|
||||
|
||||
log.info(
|
||||
"cmaes_finished",
|
||||
best_cost=f"{best_cost:.6f}",
|
||||
total_time=f"{total_time:.1f}s",
|
||||
evaluations=max_generations * population_size,
|
||||
)
|
||||
|
||||
# Log parameter comparison.
|
||||
defaults = params_to_dict(defaults_vector(specs), specs)
|
||||
for name in best_params:
|
||||
d = defaults[name]
|
||||
b = best_params[name]
|
||||
change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
|
||||
log.info(
|
||||
"param_result",
|
||||
name=name,
|
||||
default=f"{d:.6g}",
|
||||
tuned=f"{b:.6g}",
|
||||
change=f"{change_pct:+.1f}%",
|
||||
)
|
||||
|
||||
return {
|
||||
"best_params": best_params,
|
||||
"best_cost": best_cost,
|
||||
"history": history,
|
||||
"recording": str(recording_path),
|
||||
"param_names": [s.name for s in specs],
|
||||
"defaults": {s.name: s.default for s in specs},
|
||||
"motor_params": motor_params,
|
||||
"preprocess_vel": preprocess_vel,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Fit simulation parameters to a real-robot recording (CMA-ES)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="assets/rotary_cartpole",
|
||||
help="Path to robot asset directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recording",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to .npz recording file",
|
||||
)
|
||||
parser.add_argument("--sigma0", type=float, default=0.3)
|
||||
parser.add_argument("--population-size", type=int, default=20)
|
||||
parser.add_argument("--max-generations", type=int, default=200)
|
||||
parser.add_argument("--sim-dt", type=float, default=0.002)
|
||||
parser.add_argument("--substeps", type=int, default=10)
|
||||
parser.add_argument("--pos-weight", type=float, default=1.0)
|
||||
parser.add_argument("--vel-weight", type=float, default=0.1)
|
||||
parser.add_argument("--pendulum-scale", type=float, default=3.0,
|
||||
help="Multiplier for pendulum terms relative to motor (default 3.0)")
|
||||
parser.add_argument(
|
||||
"--window-duration",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Shooting window length in seconds (0 = open-loop, default 0.5)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument(
|
||||
"--no-export",
|
||||
action="store_true",
|
||||
help="Skip exporting tuned files (results JSON only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-preprocess-vel",
|
||||
action="store_true",
|
||||
help="Skip Savitzky-Golay velocity preprocessing",
|
||||
)
|
||||
parser.add_argument("--sg-window", type=int, default=7,
|
||||
help="Savitzky-Golay window length (odd, default 7)")
|
||||
parser.add_argument("--sg-polyorder", type=int, default=3,
|
||||
help="Savitzky-Golay polynomial order (default 3)")
|
||||
parser.add_argument(
|
||||
"--param-set",
|
||||
type=str,
|
||||
default="full",
|
||||
choices=list(PARAM_SETS.keys()),
|
||||
help="Parameter set to optimize: 'reduced' (6 params, fast) or 'full' (15 params)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
specs = PARAM_SETS[args.param_set]
|
||||
|
||||
result = optimize(
|
||||
robot_path=args.robot_path,
|
||||
recording_path=args.recording,
|
||||
specs=specs,
|
||||
sigma0=args.sigma0,
|
||||
population_size=args.population_size,
|
||||
max_generations=args.max_generations,
|
||||
sim_dt=args.sim_dt,
|
||||
substeps=args.substeps,
|
||||
pos_weight=args.pos_weight,
|
||||
vel_weight=args.vel_weight,
|
||||
pendulum_scale=args.pendulum_scale,
|
||||
window_duration=args.window_duration,
|
||||
seed=args.seed,
|
||||
preprocess_vel=not args.no_preprocess_vel,
|
||||
sg_window=args.sg_window,
|
||||
sg_polyorder=args.sg_polyorder,
|
||||
)
|
||||
|
||||
# Save results JSON.
|
||||
robot_path = Path(args.robot_path).resolve()
|
||||
result_path = robot_path / "sysid_result.json"
|
||||
# Convert numpy types for JSON serialisation.
|
||||
result_json = {
|
||||
k: v for k, v in result.items() if k != "history"
|
||||
}
|
||||
result_json["history_summary"] = {
|
||||
"first_cost": result["history"][0][1] if result["history"] else None,
|
||||
"final_cost": result["history"][-1][1] if result["history"] else None,
|
||||
"generations": len(result["history"]),
|
||||
}
|
||||
result_path.write_text(json.dumps(result_json, indent=2, default=str))
|
||||
log.info("results_saved", path=str(result_path))
|
||||
|
||||
# Export tuned files unless --no-export.
|
||||
if not args.no_export:
|
||||
from src.sysid.export import export_tuned_files
|
||||
|
||||
export_tuned_files(
|
||||
robot_path=args.robot_path,
|
||||
params=result["best_params"],
|
||||
motor_params=result.get("motor_params"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
425
src/sysid/rollout.py
Normal file
425
src/sysid/rollout.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
|
||||
|
||||
Given a parameter vector and a recorded action sequence, builds a MuJoCo
|
||||
model with overridden physics parameters, replays the actions, and returns
|
||||
the simulated trajectory for comparison with the real recording.
|
||||
|
||||
This module is the inner loop of the CMA-ES optimizer: it is called once
|
||||
per candidate parameter vector per generation.
|
||||
|
||||
Motor parameters are **locked** from the unified sysid result.
|
||||
The optimizer only fits
|
||||
pendulum/arm inertial parameters, pendulum joint dynamics, and
|
||||
``ctrl_limit``. The asymmetric motor model (bias, deadzone, gear
|
||||
compensation, Coulomb + Stribeck friction, viscous damping) is applied
|
||||
via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()`` —
|
||||
the same code the training runners use, so sim == sysid by construction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from src.core.robot import ActuatorConfig, JointConfig, RobotConfig
|
||||
from src.runners.mujoco import ActuatorLimits, load_mujoco_model
|
||||
from src.sysid._urdf import patch_link_inertials
|
||||
|
||||
|
||||
# ── Locked motor parameters (from the unified sysid) ────────────────
|
||||
# These are FIXED and not optimised. They come from the unified
|
||||
# 28-param sysid run (assets/rotary_cartpole/sysid_result.json,
|
||||
# cost 0.925) — Stribeck friction + action bias + ~96 ms motor lag.
|
||||
|
||||
LOCKED_MOTOR_PARAMS: dict[str, float] = {
|
||||
"actuator_gear_pos": 0.846499,
|
||||
"actuator_gear_neg": 1.183733,
|
||||
"actuator_filter_tau": 0.096263,
|
||||
"motor_damping_pos": 0.013165,
|
||||
"motor_damping_neg": 0.015452,
|
||||
"motor_armature": 0.001676,
|
||||
"motor_frictionloss_pos": 0.014244,
|
||||
"motor_frictionloss_neg": 0.001005,
|
||||
"stribeck_friction_boost": 0.068594,
|
||||
"stribeck_vel": 5.279594,
|
||||
"motor_deadzone_pos": 0.181097,
|
||||
"motor_deadzone_neg": 0.202072,
|
||||
"action_bias": 0.056566,
|
||||
}
|
||||
|
||||
|
||||
# ── Tunable parameter specification ──────────────────────────────────
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ParamSpec:
|
||||
"""Specification for a single tunable parameter."""
|
||||
|
||||
name: str
|
||||
default: float
|
||||
lower: float
|
||||
upper: float
|
||||
log_scale: bool = False # optimise in log-space (masses, inertias)
|
||||
|
||||
|
||||
# Pendulum sysid parameters — motor params are LOCKED (not here).
|
||||
# Order matters: the optimizer maps a flat vector to these specs.
|
||||
# Defaults are from the URDF exported by Fusion 360.
|
||||
ROTARY_CARTPOLE_PARAMS: list[ParamSpec] = [
|
||||
# ── Arm link (URDF) ──────────────────────────────────────────
|
||||
ParamSpec("arm_mass", 0.02110, 0.005, 0.08, log_scale=True),
|
||||
ParamSpec("arm_com_x", -0.00710, -0.03, 0.03),
|
||||
ParamSpec("arm_com_y", 0.00085, -0.02, 0.02),
|
||||
ParamSpec("arm_com_z", 0.00795, -0.02, 0.02),
|
||||
# ── Pendulum link (URDF) ─────────────────────────────────────
|
||||
ParamSpec("pendulum_mass", 0.03937, 0.010, 0.10, log_scale=True),
|
||||
ParamSpec("pendulum_com_x", 0.06025, 0.01, 0.15),
|
||||
ParamSpec("pendulum_com_y", -0.07602, -0.20, 0.0),
|
||||
ParamSpec("pendulum_com_z", -0.00346, -0.05, 0.05),
|
||||
ParamSpec("pendulum_ixx", 6.20e-05, 1e-07, 1e-03, log_scale=True),
|
||||
ParamSpec("pendulum_iyy", 3.70e-05, 1e-07, 1e-03, log_scale=True),
|
||||
ParamSpec("pendulum_izz", 7.83e-05, 1e-07, 1e-03, log_scale=True),
|
||||
ParamSpec("pendulum_ixy", -6.93e-06, -1e-03, 1e-03),
|
||||
# ── Pendulum joint dynamics ──────────────────────────────────
|
||||
ParamSpec("pendulum_damping", 0.0001, 1e-6, 0.05, log_scale=True),
|
||||
ParamSpec("pendulum_frictionloss", 0.0001, 1e-6, 0.05, log_scale=True),
|
||||
# ── Hardware realism (control pipeline) ────────────────────
|
||||
ParamSpec("ctrl_limit", 0.588, 0.45, 0.70), # MAX_MOTOR_SPEED / 255
|
||||
]
|
||||
|
||||
|
||||
# Extended set: full params + motor armature (compensates for the
|
||||
# motor-only sysid having captured arm/pendulum loading in armature).
|
||||
EXTENDED_PARAMS: list[ParamSpec] = ROTARY_CARTPOLE_PARAMS + [
|
||||
ParamSpec("motor_armature", 0.00277, 0.0005, 0.02, log_scale=True),
|
||||
]
|
||||
|
||||
|
||||
# Reduced set: only the 6 most impactful pendulum parameters.
|
||||
# Good for a fast first pass — converges in ~50 generations.
|
||||
REDUCED_PARAMS: list[ParamSpec] = [
|
||||
ParamSpec("pendulum_mass", 0.03937, 0.010, 0.10, log_scale=True),
|
||||
ParamSpec("pendulum_com_x", 0.06025, 0.01, 0.15),
|
||||
ParamSpec("pendulum_com_y", -0.07602, -0.20, 0.0),
|
||||
ParamSpec("pendulum_ixx", 6.20e-05, 1e-07, 1e-03, log_scale=True),
|
||||
ParamSpec("pendulum_damping", 0.0001, 1e-6, 0.05, log_scale=True),
|
||||
ParamSpec("pendulum_frictionloss", 0.0001, 1e-6, 0.05, log_scale=True),
|
||||
]
|
||||
|
||||
|
||||
PARAM_SETS: dict[str, list[ParamSpec]] = {
|
||||
"full": ROTARY_CARTPOLE_PARAMS,
|
||||
"extended": EXTENDED_PARAMS,
|
||||
"reduced": REDUCED_PARAMS,
|
||||
}
|
||||
|
||||
|
||||
def params_to_dict(
|
||||
values: np.ndarray, specs: list[ParamSpec] | None = None
|
||||
) -> dict[str, float]:
|
||||
"""Convert a flat parameter vector to a named dict."""
|
||||
if specs is None:
|
||||
specs = ROTARY_CARTPOLE_PARAMS
|
||||
return {s.name: float(values[i]) for i, s in enumerate(specs)}
|
||||
|
||||
|
||||
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
|
||||
"""Return the default parameter vector (in natural space)."""
|
||||
if specs is None:
|
||||
specs = ROTARY_CARTPOLE_PARAMS
|
||||
return np.array([s.default for s in specs], dtype=np.float64)
|
||||
|
||||
|
||||
def bounds_arrays(
|
||||
specs: list[ParamSpec] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Return (lower, upper) bound arrays."""
|
||||
if specs is None:
|
||||
specs = ROTARY_CARTPOLE_PARAMS
|
||||
lo = np.array([s.lower for s in specs], dtype=np.float64)
|
||||
hi = np.array([s.upper for s in specs], dtype=np.float64)
|
||||
return lo, hi
|
||||
|
||||
|
||||
# ── MuJoCo model building with parameter overrides ──────────────────
|
||||
|
||||
|
||||
def _build_model(
|
||||
robot_path: Path,
|
||||
params: dict[str, float],
|
||||
motor_params: dict[str, float] | None = None,
|
||||
) -> tuple[mujoco.MjModel, ActuatorConfig]:
|
||||
"""Build a MuJoCo model with sysid overrides.
|
||||
|
||||
Returns (model, actuator) — use ``actuator.transform_ctrl()`` and
|
||||
``actuator.compute_motor_force()`` in the rollout loop.
|
||||
"""
|
||||
if motor_params is None:
|
||||
motor_params = LOCKED_MOTOR_PARAMS
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
|
||||
# ── Patch URDF inertial parameters to a temp file ────────────
|
||||
robot_yaml = yaml.safe_load((robot_path / "robot.yaml").read_text())
|
||||
urdf_path = robot_path / robot_yaml["urdf"]
|
||||
|
||||
tree = ET.parse(urdf_path)
|
||||
patch_link_inertials(tree.getroot(), params)
|
||||
|
||||
fd, tmp_urdf = tempfile.mkstemp(
|
||||
suffix=".urdf", prefix="_sysid_", dir=str(robot_path),
|
||||
)
|
||||
os.close(fd)
|
||||
tmp_urdf_path = Path(tmp_urdf)
|
||||
tree.write(str(tmp_urdf_path), xml_declaration=True, encoding="unicode")
|
||||
|
||||
# ── Build RobotConfig with full motor sysid values ───────────
|
||||
gear_pos = motor_params.get("actuator_gear_pos", 0.424182)
|
||||
gear_neg = motor_params.get("actuator_gear_neg", 0.425031)
|
||||
motor_armature = params.get(
|
||||
"motor_armature",
|
||||
motor_params.get("motor_armature", 0.00277342),
|
||||
)
|
||||
pend_damping = params.get("pendulum_damping", 0.0001)
|
||||
pend_frictionloss = params.get("pendulum_frictionloss", 0.0001)
|
||||
|
||||
act_cfg = robot_yaml["actuators"][0]
|
||||
ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])
|
||||
|
||||
# The fitted ctrl_limit overrides the YAML ctrl_range so the rollout
|
||||
# saturates at exactly the identified PWM bound.
|
||||
if "ctrl_limit" in params:
|
||||
ctrl_lo, ctrl_hi = -params["ctrl_limit"], params["ctrl_limit"]
|
||||
|
||||
actuator = ActuatorConfig(
|
||||
joint=act_cfg["joint"],
|
||||
type="motor",
|
||||
gear=(gear_pos, gear_neg),
|
||||
ctrl_range=(ctrl_lo, ctrl_hi),
|
||||
deadzone=(
|
||||
motor_params.get("motor_deadzone_pos", 0.181),
|
||||
motor_params.get("motor_deadzone_neg", 0.202),
|
||||
),
|
||||
damping=(
|
||||
motor_params.get("motor_damping_pos", 0.013),
|
||||
motor_params.get("motor_damping_neg", 0.015),
|
||||
),
|
||||
frictionloss=(
|
||||
motor_params.get("motor_frictionloss_pos", 0.014),
|
||||
motor_params.get("motor_frictionloss_neg", 0.001),
|
||||
),
|
||||
filter_tau=motor_params.get("actuator_filter_tau", 0.096),
|
||||
viscous_quadratic=motor_params.get("viscous_quadratic", 0.0),
|
||||
back_emf_gain=motor_params.get("back_emf_gain", 0.0),
|
||||
stribeck_friction_boost=motor_params.get("stribeck_friction_boost", 0.0),
|
||||
stribeck_vel=motor_params.get("stribeck_vel", 2.0),
|
||||
action_bias=motor_params.get("action_bias", 0.0),
|
||||
)
|
||||
|
||||
robot = RobotConfig(
|
||||
urdf_path=tmp_urdf_path,
|
||||
actuators=[actuator],
|
||||
joints={
|
||||
"motor_joint": JointConfig(
|
||||
damping=0.0,
|
||||
armature=motor_armature,
|
||||
frictionloss=0.0,
|
||||
),
|
||||
"pendulum_joint": JointConfig(
|
||||
damping=pend_damping,
|
||||
frictionloss=pend_frictionloss,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
model = load_mujoco_model(robot)
|
||||
finally:
|
||||
tmp_urdf_path.unlink(missing_ok=True)
|
||||
|
||||
return model, actuator
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ── Simulation rollout ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def rollout(
|
||||
robot_path: str | Path,
|
||||
params: dict[str, float],
|
||||
actions: np.ndarray,
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
motor_params: dict[str, float] | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Replay recorded actions in MuJoCo with overridden parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : asset directory
|
||||
params : named parameter overrides (pendulum/arm only)
|
||||
actions : (N,) normalised actions [-1, 1] from the recording
|
||||
sim_dt : MuJoCo physics timestep
|
||||
substeps : physics substeps per control step
|
||||
motor_params : locked motor params (default: LOCKED_MOTOR_PARAMS)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict with keys: motor_angle, motor_vel, pendulum_angle, pendulum_vel
|
||||
Each is an (N,) numpy array of simulated values.
|
||||
"""
|
||||
if motor_params is None:
|
||||
motor_params = LOCKED_MOTOR_PARAMS
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
model, actuator = _build_model(robot_path, params, motor_params)
|
||||
model.opt.timestep = sim_dt
|
||||
data = mujoco.MjData(model)
|
||||
mujoco.mj_resetData(model, data)
|
||||
|
||||
n = len(actions)
|
||||
|
||||
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
||||
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
||||
sim_pend_angle = np.zeros(n, dtype=np.float64)
|
||||
sim_pend_vel = np.zeros(n, dtype=np.float64)
|
||||
|
||||
limits = ActuatorLimits(model)
|
||||
|
||||
for i in range(n):
|
||||
# transform_ctrl clips to the (fitted) ctrl_range internally.
|
||||
ctrl = actuator.transform_ctrl(float(actions[i]))
|
||||
data.ctrl[0] = ctrl
|
||||
|
||||
for _ in range(substeps):
|
||||
limits.enforce(model, data)
|
||||
data.qfrc_applied[0] = actuator.compute_motor_force(data.qvel[0], ctrl)
|
||||
mujoco.mj_step(model, data)
|
||||
|
||||
sim_motor_angle[i] = data.qpos[0]
|
||||
sim_pend_angle[i] = data.qpos[1]
|
||||
sim_motor_vel[i] = data.qvel[0]
|
||||
sim_pend_vel[i] = data.qvel[1]
|
||||
|
||||
return {
|
||||
"motor_angle": sim_motor_angle,
|
||||
"motor_vel": sim_motor_vel,
|
||||
"pendulum_angle": sim_pend_angle,
|
||||
"pendulum_vel": sim_pend_vel,
|
||||
}
|
||||
|
||||
|
||||
def windowed_rollout(
|
||||
robot_path: str | Path,
|
||||
params: dict[str, float],
|
||||
recording: dict[str, np.ndarray],
|
||||
window_duration: float = 0.5,
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
motor_params: dict[str, float] | None = None,
|
||||
) -> dict[str, np.ndarray | float]:
|
||||
"""Multiple-shooting rollout — split recording into short windows.
|
||||
|
||||
For each window:
|
||||
1. Initialize MuJoCo state from the real qpos/qvel at the window start.
|
||||
2. Replay the recorded actions within the window.
|
||||
3. Record the simulated output.
|
||||
|
||||
Motor dynamics (asymmetric friction, damping, back-EMF, etc.) are
|
||||
applied via qfrc_applied using the locked motor sysid parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
robot_path : asset directory
|
||||
params : named parameter overrides (pendulum/arm only)
|
||||
recording : dict with keys time, action, motor_angle, motor_vel,
|
||||
pendulum_angle, pendulum_vel (all 1D arrays of length N)
|
||||
window_duration : length of each shooting window in seconds
|
||||
sim_dt : MuJoCo physics timestep
|
||||
substeps : physics substeps per control step
|
||||
motor_params : locked motor params (default: LOCKED_MOTOR_PARAMS)
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict with:
|
||||
motor_angle, motor_vel, pendulum_angle, pendulum_vel — (N,) arrays
|
||||
(stitched from per-window simulations)
|
||||
n_windows — number of windows used
|
||||
"""
|
||||
if motor_params is None:
|
||||
motor_params = LOCKED_MOTOR_PARAMS
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
model, actuator = _build_model(robot_path, params, motor_params)
|
||||
model.opt.timestep = sim_dt
|
||||
data = mujoco.MjData(model)
|
||||
|
||||
times = recording["time"]
|
||||
actions = recording["action"]
|
||||
real_motor = recording["motor_angle"]
|
||||
real_motor_vel = recording["motor_vel"]
|
||||
real_pend = recording["pendulum_angle"]
|
||||
real_pend_vel = recording["pendulum_vel"]
|
||||
n = len(actions)
|
||||
|
||||
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
||||
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
||||
sim_pend_angle = np.zeros(n, dtype=np.float64)
|
||||
sim_pend_vel = np.zeros(n, dtype=np.float64)
|
||||
|
||||
limits = ActuatorLimits(model)
|
||||
|
||||
t0 = times[0]
|
||||
t_end = times[-1]
|
||||
window_starts: list[int] = []
|
||||
current_t = t0
|
||||
while current_t < t_end:
|
||||
idx = int(np.searchsorted(times, current_t))
|
||||
idx = min(idx, n - 1)
|
||||
window_starts.append(idx)
|
||||
current_t += window_duration
|
||||
|
||||
n_windows = len(window_starts)
|
||||
|
||||
for w, w_start in enumerate(window_starts):
|
||||
w_end = window_starts[w + 1] if w + 1 < n_windows else n
|
||||
|
||||
mujoco.mj_resetData(model, data)
|
||||
data.qpos[0] = real_motor[w_start]
|
||||
data.qpos[1] = real_pend[w_start]
|
||||
data.qvel[0] = real_motor_vel[w_start]
|
||||
data.qvel[1] = real_pend_vel[w_start]
|
||||
data.ctrl[:] = 0.0
|
||||
mujoco.mj_forward(model, data)
|
||||
|
||||
for i in range(w_start, w_end):
|
||||
# transform_ctrl clips to the (fitted) ctrl_range internally.
|
||||
ctrl = actuator.transform_ctrl(float(actions[i]))
|
||||
data.ctrl[0] = ctrl
|
||||
|
||||
for _ in range(substeps):
|
||||
limits.enforce(model, data)
|
||||
data.qfrc_applied[0] = actuator.compute_motor_force(data.qvel[0], ctrl)
|
||||
mujoco.mj_step(model, data)
|
||||
|
||||
sim_motor_angle[i] = data.qpos[0]
|
||||
sim_pend_angle[i] = data.qpos[1]
|
||||
sim_motor_vel[i] = data.qvel[0]
|
||||
sim_pend_vel[i] = data.qvel[1]
|
||||
|
||||
return {
|
||||
"motor_angle": sim_motor_angle,
|
||||
"motor_vel": sim_motor_vel,
|
||||
"pendulum_angle": sim_pend_angle,
|
||||
"pendulum_vel": sim_pend_vel,
|
||||
"n_windows": n_windows,
|
||||
}
|
||||
248
src/sysid/visualize.py
Normal file
248
src/sysid/visualize.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Visualise system identification results — real vs simulated trajectories.
|
||||
|
||||
Loads a recording and runs simulation with both the original and tuned
|
||||
parameters, then plots a 4-panel comparison (motor angle, motor vel,
|
||||
pendulum angle, pendulum vel) over time.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||
|
||||
# Also compare with tuned parameters:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording <file>.npz \
|
||||
--result assets/rotary_cartpole/sysid_result.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def _run_sim(
|
||||
robot_path: Path,
|
||||
params: dict[str, float],
|
||||
recording: dict[str, np.ndarray],
|
||||
window_duration: float,
|
||||
sim_dt: float,
|
||||
substeps: int,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Run windowed or open-loop rollout depending on window_duration."""
|
||||
from src.sysid.rollout import rollout, windowed_rollout
|
||||
|
||||
if window_duration > 0:
|
||||
return windowed_rollout(
|
||||
robot_path=robot_path, params=params, recording=recording,
|
||||
window_duration=window_duration, sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
return rollout(
|
||||
robot_path=robot_path, params=params, actions=recording["action"],
|
||||
substeps=substeps,
|
||||
)
|
||||
|
||||
|
||||
def visualize(
|
||||
robot_path: str | Path,
|
||||
recording_path: str | Path,
|
||||
result_path: str | Path | None = None,
|
||||
sim_dt: float = 0.002,
|
||||
substeps: int = 10,
|
||||
window_duration: float = 0.5,
|
||||
save_path: str | Path | None = None,
|
||||
show: bool = True,
|
||||
) -> None:
|
||||
"""Generate comparison plot."""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from src.sysid.rollout import (
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
defaults_vector,
|
||||
params_to_dict,
|
||||
)
|
||||
|
||||
robot_path = Path(robot_path).resolve()
|
||||
recording = dict(np.load(recording_path))
|
||||
|
||||
sim_kwargs = dict(
|
||||
robot_path=robot_path, recording=recording,
|
||||
window_duration=window_duration, sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
)
|
||||
|
||||
t = recording["time"]
|
||||
actions = recording["action"]
|
||||
|
||||
# ── Simulate with default parameters ─────────────────────────
|
||||
default_params = params_to_dict(
|
||||
defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
|
||||
)
|
||||
log.info("simulating_default_params", windowed=window_duration > 0)
|
||||
sim_default = _run_sim(params=default_params, **sim_kwargs)
|
||||
|
||||
# ── Simulate with tuned parameters (if available) ────────────
|
||||
# Resolve result path (explicit or auto-detect).
|
||||
if result_path is None:
|
||||
auto = robot_path / "sysid_result.json"
|
||||
if auto.exists():
|
||||
result_path = auto
|
||||
|
||||
sim_tuned = None
|
||||
tuned_cost = None
|
||||
if result_path is not None:
|
||||
result_path = Path(result_path)
|
||||
if result_path.exists():
|
||||
result = json.loads(result_path.read_text())
|
||||
tuned_params = result.get("best_params", {})
|
||||
tuned_cost = result.get("best_cost")
|
||||
log.info("simulating_tuned_params", cost=tuned_cost)
|
||||
sim_tuned = _run_sim(params=tuned_params, **sim_kwargs)
|
||||
else:
|
||||
log.warning("result_file_not_found", path=str(result_path))
|
||||
|
||||
# ── Plot ─────────────────────────────────────────────────────
|
||||
fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)
|
||||
|
||||
channels = [
|
||||
("motor_angle", "Motor Angle (rad)"),
|
||||
("motor_vel", "Motor Velocity (rad/s)"),
|
||||
("pendulum_angle", "Pendulum Angle (rad)"),
|
||||
("pendulum_vel", "Pendulum Velocity (rad/s)"),
|
||||
]
|
||||
|
||||
for ax, (key, ylabel) in zip(axes[:4], channels):
|
||||
real = recording[key]
|
||||
|
||||
ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
|
||||
ax.plot(
|
||||
t,
|
||||
sim_default[key],
|
||||
"--",
|
||||
color="#d62728",
|
||||
linewidth=1.0,
|
||||
alpha=0.7,
|
||||
label="Sim (original)",
|
||||
)
|
||||
if sim_tuned is not None:
|
||||
ax.plot(
|
||||
t,
|
||||
sim_tuned[key],
|
||||
"--",
|
||||
color="#2ca02c",
|
||||
linewidth=1.0,
|
||||
alpha=0.7,
|
||||
label="Sim (tuned)",
|
||||
)
|
||||
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.legend(loc="upper right", fontsize=8)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Action plot (bottom panel).
|
||||
axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
|
||||
axes[4].set_ylabel("Action (norm)")
|
||||
axes[4].set_xlabel("Time (s)")
|
||||
axes[4].grid(True, alpha=0.3)
|
||||
axes[4].set_ylim(-1.1, 1.1)
|
||||
|
||||
# Title with cost info.
|
||||
title = "System Identification — Real vs Simulated Trajectories"
|
||||
if tuned_cost is not None:
|
||||
# Compute original cost for comparison.
|
||||
from src.sysid.optimize import cost_function
|
||||
|
||||
orig_cost = cost_function(
|
||||
defaults_vector(ROTARY_CARTPOLE_PARAMS),
|
||||
recording,
|
||||
robot_path,
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
sim_dt=sim_dt,
|
||||
substeps=substeps,
|
||||
window_duration=window_duration,
|
||||
)
|
||||
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
||||
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
||||
title += f" ({improvement:+.1f}%)"
|
||||
|
||||
fig.suptitle(title, fontsize=12)
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
save_path = Path(save_path)
|
||||
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
|
||||
log.info("figure_saved", path=str(save_path))
|
||||
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Visualise system identification results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="assets/rotary_cartpole",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recording",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to .npz recording file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to sysid_result.json (auto-detected if omitted)",
|
||||
)
|
||||
parser.add_argument("--sim-dt", type=float, default=0.002)
|
||||
parser.add_argument("--substeps", type=int, default=10)
|
||||
parser.add_argument(
|
||||
"--window-duration",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Shooting window length in seconds (0 = open-loop)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Save figure to this path (PNG, PDF, …)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-show",
|
||||
action="store_true",
|
||||
help="Don't show interactive window (useful for CI)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
visualize(
|
||||
robot_path=args.robot_path,
|
||||
recording_path=args.recording,
|
||||
result_path=args.result,
|
||||
sim_dt=args.sim_dt,
|
||||
substeps=args.substeps,
|
||||
window_duration=args.window_duration,
|
||||
save_path=args.save,
|
||||
show=not args.no_show,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,19 +4,24 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
import tqdm
|
||||
from clearml import Logger
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
from src.core.runner import BaseRunner
|
||||
from clearml import Task, Logger
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from src.models.mlp import SharedMLP
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainerConfig:
|
||||
# PPO
|
||||
rollout_steps: int = 2048
|
||||
learning_epochs: int = 8
|
||||
mini_batches: int = 4
|
||||
@@ -26,33 +31,41 @@ class TrainerConfig:
|
||||
clip_ratio: float = 0.2
|
||||
value_loss_scale: float = 0.5
|
||||
entropy_loss_scale: float = 0.01
|
||||
kl_threshold: float = 0.01 # KL-adaptive LR target; 0 = fixed LR
|
||||
|
||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||
|
||||
# Policy
|
||||
initial_log_std: float = 0.5 # initial exploration noise
|
||||
min_log_std: float = -2.0 # minimum exploration noise
|
||||
max_log_std: float = 2.0 # maximum exploration noise (2.0 ≈ σ=7.4)
|
||||
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
checkpoint_interval: int = 50_000
|
||||
|
||||
# Video recording
|
||||
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
|
||||
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
|
||||
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
|
||||
# Video recording (uploaded to ClearML)
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||
|
||||
clearml_project: str | None = None
|
||||
clearml_task: str | None = None
|
||||
# History encoder (implicit adaptation). The window size comes from
|
||||
# the runner (runner.history_length) — single source of truth.
|
||||
embedding_dim: int = 32 # history encoder output dimension
|
||||
|
||||
|
||||
# ── Video-recording trainer ──────────────────────────────────────────
|
||||
|
||||
class VideoRecordingTrainer(SequentialTrainer):
|
||||
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
|
||||
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
|
||||
|
||||
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
|
||||
super().__init__(env=env, agents=agents, cfg=cfg)
|
||||
self._trainer_config = trainer_config
|
||||
self._tcfg = trainer_config
|
||||
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
|
||||
|
||||
def single_agent_train(self) -> None:
|
||||
"""Override to add periodic video recording."""
|
||||
assert self.num_simultaneous_agents == 1
|
||||
assert self.env.num_agents == 1
|
||||
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
|
||||
|
||||
states, infos = self.env.reset()
|
||||
|
||||
@@ -61,26 +74,17 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
disable=self.disable_progressbar,
|
||||
file=sys.stdout,
|
||||
):
|
||||
# Pre-interaction
|
||||
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
with torch.no_grad():
|
||||
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
|
||||
|
||||
if not self.headless:
|
||||
self.env.render()
|
||||
|
||||
self.agents.record_transition(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_states=next_states,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
infos=infos,
|
||||
timestep=timestep,
|
||||
timesteps=self.timesteps,
|
||||
states=states, actions=actions, rewards=rewards,
|
||||
next_states=next_states, terminated=terminated,
|
||||
truncated=truncated, infos=infos,
|
||||
timestep=timestep, timesteps=self.timesteps,
|
||||
)
|
||||
|
||||
if self.environment_info in infos:
|
||||
@@ -90,7 +94,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
|
||||
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
# Reset environments
|
||||
# Auto-reset for multi-env; single-env resets manually
|
||||
if self.env.num_envs > 1:
|
||||
states = next_states
|
||||
else:
|
||||
@@ -100,111 +104,125 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
else:
|
||||
states = next_states
|
||||
|
||||
# Record video at intervals
|
||||
cfg = self._trainer_config
|
||||
# Periodic video recording. Recording steps the (shared) envs,
|
||||
# so it returns a freshly reset observation — the training loop
|
||||
# MUST continue from it, otherwise the recorded transitions no
|
||||
# longer match the actual env state.
|
||||
if (
|
||||
cfg
|
||||
and cfg.record_video_every > 0
|
||||
and (timestep + 1) % cfg.record_video_every == 0
|
||||
self._tcfg
|
||||
and self._tcfg.record_video_every > 0
|
||||
and (timestep + 1) % self._tcfg.record_video_every == 0
|
||||
):
|
||||
self._record_video(timestep + 1)
|
||||
fresh_states = self._record_video(timestep + 1)
|
||||
if fresh_states is not None:
|
||||
states = fresh_states
|
||||
|
||||
def _get_video_fps(self) -> int:
|
||||
"""Derive video fps from the simulation rate, or use configured value."""
|
||||
cfg = self._trainer_config
|
||||
if cfg.record_video_fps > 0:
|
||||
return cfg.record_video_fps
|
||||
# Auto-derive from runner's simulation parameters
|
||||
runner = self.env
|
||||
dt = getattr(runner.config, "dt", 0.02)
|
||||
substeps = getattr(runner.config, "substeps", 1)
|
||||
# ── helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _get_fps(self) -> int:
|
||||
if self._tcfg and self._tcfg.record_video_fps > 0:
|
||||
return self._tcfg.record_video_fps
|
||||
dt = getattr(self.env.config, "dt", 0.02)
|
||||
substeps = getattr(self.env.config, "substeps", 1)
|
||||
# SerialRunner has dt but no substeps — dt *is* the control period.
|
||||
return max(1, int(round(1.0 / (dt * substeps))))
|
||||
|
||||
def _record_video(self, timestep: int) -> None:
|
||||
"""Record evaluation episodes and upload to ClearML."""
|
||||
def _record_video(self, timestep: int) -> torch.Tensor | None:
|
||||
"""Record an eval episode and upload it to ClearML.
|
||||
|
||||
Returns the freshly reset observation the training loop should
|
||||
continue from (the recording steps the shared envs), or ``None``
|
||||
if even the final reset failed.
|
||||
"""
|
||||
try:
|
||||
import imageio.v3 as iio
|
||||
except ImportError:
|
||||
iio = None
|
||||
|
||||
# Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
|
||||
# failure crash training — log it and skip the video.
|
||||
if iio is not None:
|
||||
try:
|
||||
import imageio as iio
|
||||
except ImportError:
|
||||
return
|
||||
fps = self._get_fps()
|
||||
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
frames: list[np.ndarray] = []
|
||||
|
||||
cfg = self._trainer_config
|
||||
fps = self._get_video_fps()
|
||||
min_frames = int(cfg.record_video_min_seconds * fps)
|
||||
max_frames = min_frames * 3 # hard cap to prevent runaway recording
|
||||
frames: list[np.ndarray] = []
|
||||
|
||||
while len(frames) < min_frames and len(frames) < max_frames:
|
||||
obs, _ = self.env.reset()
|
||||
done = False
|
||||
steps = 0
|
||||
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
while not done and steps < max_episode_steps:
|
||||
obs, _ = self.env.reset()
|
||||
with torch.no_grad():
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
frame = self.env.render(mode="rgb_array")
|
||||
if frame is not None:
|
||||
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
|
||||
done = (terminated | truncated).any().item()
|
||||
steps += 1
|
||||
if len(frames) >= max_frames:
|
||||
break
|
||||
for _ in range(max_steps):
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
|
||||
if frames:
|
||||
video_path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(video_path, frames, fps=fps)
|
||||
frame = self.env.render()
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
logger = Logger.current_logger()
|
||||
if logger:
|
||||
logger.report_media(
|
||||
title="Training Video",
|
||||
series=f"step_{timestep}",
|
||||
local_path=video_path,
|
||||
iteration=timestep,
|
||||
)
|
||||
if (terminated | truncated).any().item():
|
||||
break
|
||||
|
||||
# Reset back to training state after recording
|
||||
self.env.reset()
|
||||
if frames:
|
||||
path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(path, frames, fps=fps)
|
||||
|
||||
logger = Logger.current_logger()
|
||||
if logger:
|
||||
logger.report_media(
|
||||
"Training Video", f"step_{timestep}",
|
||||
local_path=path, iteration=timestep,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("video_recording_failed", timestep=timestep, error=str(exc))
|
||||
|
||||
# Always leave the envs freshly reset and hand the new observation
|
||||
# back to the training loop.
|
||||
try:
|
||||
with torch.no_grad():
|
||||
states, _ = self.env.reset()
|
||||
return states
|
||||
except Exception as exc:
|
||||
log.warning("post_video_reset_failed", timestep=timestep, error=str(exc))
|
||||
return None
|
||||
|
||||
|
||||
# ── Main trainer ─────────────────────────────────────────────────────
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, runner: BaseRunner, config: TrainerConfig):
|
||||
self.runner = runner
|
||||
self.config = config
|
||||
|
||||
self._init_clearml()
|
||||
self._init_agent()
|
||||
|
||||
def _init_clearml(self) -> None:
|
||||
if self.config.clearml_project and self.config.clearml_task:
|
||||
self.clearml_task = Task.init(
|
||||
project_name=self.config.clearml_project,
|
||||
task_name=self.config.clearml_task,
|
||||
)
|
||||
else:
|
||||
self.clearml_task = None
|
||||
|
||||
def _init_agent(self) -> None:
|
||||
device: torch.device = self.runner.device
|
||||
obs_space: spaces.Space = self.runner.observation_space
|
||||
act_space: spaces.Space = self.runner.action_space
|
||||
num_envs: int = self.runner.num_envs
|
||||
device = self.runner.device
|
||||
obs_space = self.runner.observation_space
|
||||
act_space = self.runner.action_space
|
||||
|
||||
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
|
||||
self.memory = RandomMemory(
|
||||
memory_size=self.config.rollout_steps,
|
||||
num_envs=self.runner.num_envs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.model: SharedMLP = SharedMLP(
|
||||
# Determine raw obs dim (without history augmentation) and the
|
||||
# history window size — both come from the runner so the model
|
||||
# always matches the observation layout it produces.
|
||||
raw_obs_dim = self.runner.env.observation_space.shape[0]
|
||||
history_length = getattr(self.runner.config, "history_length", 0)
|
||||
|
||||
self.model = SharedMLP(
|
||||
observation_space=obs_space,
|
||||
action_space=act_space,
|
||||
device=device,
|
||||
hidden_sizes=self.config.hidden_sizes,
|
||||
initial_log_std=self.config.initial_log_std,
|
||||
min_log_std=self.config.min_log_std,
|
||||
max_log_std=self.config.max_log_std,
|
||||
history_length=history_length,
|
||||
raw_obs_dim=raw_obs_dim,
|
||||
embedding_dim=self.config.embedding_dim,
|
||||
)
|
||||
|
||||
models = {
|
||||
"policy": self.model,
|
||||
"value": self.model,
|
||||
}
|
||||
models = {"policy": self.model, "value": self.model}
|
||||
|
||||
agent_cfg = PPO_DEFAULT_CONFIG.copy()
|
||||
agent_cfg.update({
|
||||
@@ -217,9 +235,28 @@ class Trainer:
|
||||
"ratio_clip": self.config.clip_ratio,
|
||||
"value_loss_scale": self.config.value_loss_scale,
|
||||
"entropy_loss_scale": self.config.entropy_loss_scale,
|
||||
# Truncation (time limit) must bootstrap from the value function;
|
||||
# without this the value target is biased at every max_steps cut.
|
||||
"time_limit_bootstrap": True,
|
||||
"state_preprocessor": RunningStandardScaler,
|
||||
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
||||
"value_preprocessor": RunningStandardScaler,
|
||||
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
||||
})
|
||||
if self.config.kl_threshold > 0:
|
||||
from skrl.resources.schedulers.torch import KLAdaptiveLR
|
||||
agent_cfg["learning_rate_scheduler"] = KLAdaptiveLR
|
||||
agent_cfg["learning_rate_scheduler_kwargs"] = {
|
||||
"kl_threshold": self.config.kl_threshold,
|
||||
}
|
||||
# Wire up logging frequency: write_interval is in timesteps.
|
||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||
self.config.checkpoint_interval, self.config.rollout_steps
|
||||
)
|
||||
|
||||
self.agent: PPO = PPO(
|
||||
self.agent = PPO(
|
||||
models=models,
|
||||
memory=self.memory,
|
||||
observation_space=obs_space,
|
||||
@@ -238,6 +275,4 @@ class Trainer:
|
||||
trainer.train()
|
||||
|
||||
def close(self) -> None:
|
||||
self.runner.close()
|
||||
if self.clearml_task:
|
||||
self.clearml_task.close()
|
||||
self.runner.close()
|
||||
7
tests/conftest.py
Normal file
7
tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Make `src.*` importable regardless of pytest invocation directory.
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
79
tests/test_reward.py
Normal file
79
tests/test_reward.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Reward design tests — balancing must strictly dominate spinning."""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.envs.rotary_cartpole import (
|
||||
RotaryCartPoleConfig,
|
||||
RotaryCartPoleEnv,
|
||||
RotaryCartPoleState,
|
||||
)
|
||||
|
||||
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
|
||||
|
||||
|
||||
def _env() -> RotaryCartPoleEnv:
|
||||
return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||
|
||||
|
||||
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
|
||||
t = lambda v: torch.tensor([float(v)])
|
||||
return RotaryCartPoleState(
|
||||
motor_angle=t(motor), motor_vel=t(motor_vel),
|
||||
pendulum_angle=t(pend), pendulum_vel=t(pend_vel),
|
||||
)
|
||||
|
||||
|
||||
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
|
||||
a = torch.tensor([[float(action)]])
|
||||
pa = torch.tensor([[float(prev_action)]])
|
||||
return float(env.compute_rewards(state, a, pa)[0])
|
||||
|
||||
|
||||
def test_balancing_beats_spinning_through_upright():
|
||||
env = _env()
|
||||
balanced = _state(pend=math.pi, pend_vel=0.0)
|
||||
spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top
|
||||
assert _reward(env, balanced) > 2.0 * _reward(env, spinning)
|
||||
|
||||
|
||||
def test_average_spin_cycle_reward_below_balance():
|
||||
"""Mean reward over a full revolution at high speed << balanced reward."""
|
||||
env = _env()
|
||||
angles = torch.linspace(0, 2 * math.pi, 32)
|
||||
spin_rewards = [
|
||||
_reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
|
||||
]
|
||||
mean_spin = sum(spin_rewards) / len(spin_rewards)
|
||||
balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
|
||||
assert balanced > 3.0 * mean_spin
|
||||
|
||||
|
||||
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
|
||||
env = _env()
|
||||
over_limit = _state(motor=math.radians(95.0), pend=math.pi)
|
||||
assert _reward(env, over_limit) == -10.0
|
||||
assert bool(env.compute_terminations(over_limit)[0])
|
||||
|
||||
|
||||
def test_action_rate_penalty_reduces_reward():
|
||||
env = _env()
|
||||
s = _state(pend=math.pi)
|
||||
smooth = _reward(env, s, action=0.5, prev_action=0.5)
|
||||
jerky = _reward(env, s, action=0.5, prev_action=-0.5)
|
||||
assert smooth > jerky
|
||||
assert smooth - jerky == pytest.approx(
|
||||
env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6,
|
||||
)
|
||||
|
||||
|
||||
def test_initial_state_ranges_widen_pendulum_only():
|
||||
env = _env()
|
||||
qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
|
||||
assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
|
||||
assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0)
|
||||
assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0)
|
||||
assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()
|
||||
125
tests/test_robot_config.py
Normal file
125
tests/test_robot_config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Robot config loading + motor model unit tests."""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.core.robot import ActuatorConfig, load_robot_config
|
||||
|
||||
ROBOT_DIR = Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole"
|
||||
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_canonical_robot_yaml_loads_full_motor_model():
|
||||
robot = load_robot_config(ROBOT_DIR)
|
||||
act = robot.actuators[0]
|
||||
|
||||
assert act.has_motor_model
|
||||
# Tuned (unified sysid) values must survive the round-trip.
|
||||
assert act.filter_tau == pytest.approx(0.096263)
|
||||
assert act.stribeck_friction_boost == pytest.approx(0.068594)
|
||||
assert act.stribeck_vel == pytest.approx(5.279594)
|
||||
assert act.action_bias == pytest.approx(0.056566)
|
||||
assert act.gear == pytest.approx((0.846499, 1.183733))
|
||||
|
||||
|
||||
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
|
||||
(tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
|
||||
(tmp_path / "robot.yaml").write_text(
|
||||
"urdf: dummy.urdf\n"
|
||||
"actuators:\n"
|
||||
" - joint: j\n"
|
||||
" type: motor\n"
|
||||
" gear: [1.0, 1.0]\n"
|
||||
" some_future_field: 42\n"
|
||||
)
|
||||
robot = load_robot_config(tmp_path) # must not raise
|
||||
assert robot.actuators[0].joint == "j"
|
||||
|
||||
|
||||
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def act() -> ActuatorConfig:
|
||||
return ActuatorConfig(
|
||||
joint="m",
|
||||
gear=(0.8, 1.2),
|
||||
ctrl_range=(-0.6, 0.6),
|
||||
deadzone=(0.15, 0.20),
|
||||
frictionloss=(0.014, 0.001),
|
||||
damping=(0.013, 0.015),
|
||||
stribeck_friction_boost=0.07,
|
||||
stribeck_vel=5.0,
|
||||
action_bias=0.05,
|
||||
)
|
||||
|
||||
|
||||
def test_transform_ctrl_clips_to_ctrl_range(act):
|
||||
# 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
|
||||
out = act.transform_ctrl(1.0)
|
||||
assert out == pytest.approx((0.6 + 0.05) * 0.8 / 1.0)
|
||||
|
||||
|
||||
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
|
||||
# 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
|
||||
assert act.transform_ctrl(0.05) == 0.0
|
||||
# -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
|
||||
assert act.transform_ctrl(-0.15) == 0.0
|
||||
|
||||
|
||||
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
|
||||
pos = act.transform_ctrl(0.5) # (0.55) * 0.8
|
||||
neg = act.transform_ctrl(-0.5) # (-0.45) * 1.2
|
||||
assert pos == pytest.approx(0.55 * 0.8)
|
||||
assert neg == pytest.approx(-0.45 * 1.2)
|
||||
|
||||
|
||||
def test_transform_action_matches_transform_ctrl_elementwise(act):
|
||||
vals = torch.linspace(-1.2, 1.2, 49)
|
||||
batched = act.transform_action(vals.clone())
|
||||
scalar = torch.tensor([act.transform_ctrl(float(v)) for v in vals])
|
||||
assert torch.allclose(batched, scalar, atol=1e-6)
|
||||
|
||||
|
||||
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
|
||||
|
||||
|
||||
def test_friction_opposes_motion(act):
|
||||
assert act.compute_motor_force(vel=2.0, ctrl=0.0) < 0
|
||||
assert act.compute_motor_force(vel=-2.0, ctrl=0.0) > 0
|
||||
assert act.compute_motor_force(vel=0.0, ctrl=0.0) == 0.0
|
||||
|
||||
|
||||
def test_stribeck_boost_decays_with_speed(act):
|
||||
"""Friction torque magnitude (minus damping) is higher near standstill."""
|
||||
no_strb = ActuatorConfig(
|
||||
joint="m", gear=act.gear, frictionloss=act.frictionloss,
|
||||
damping=(0.0, 0.0),
|
||||
)
|
||||
with_strb = ActuatorConfig(
|
||||
joint="m", gear=act.gear, frictionloss=act.frictionloss,
|
||||
damping=(0.0, 0.0),
|
||||
stribeck_friction_boost=0.07, stribeck_vel=5.0,
|
||||
)
|
||||
v_slow, v_fast = 0.1, 50.0
|
||||
extra_slow = abs(with_strb.compute_motor_force(v_slow, 0.0)) - abs(
|
||||
no_strb.compute_motor_force(v_slow, 0.0))
|
||||
extra_fast = abs(with_strb.compute_motor_force(v_fast, 0.0)) - abs(
|
||||
no_strb.compute_motor_force(v_fast, 0.0))
|
||||
|
||||
assert extra_slow == pytest.approx(
|
||||
0.07 * math.exp(-((v_slow / 5.0) ** 2)), abs=1e-9)
|
||||
assert extra_fast < extra_slow
|
||||
assert extra_fast == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
|
||||
def test_friction_scale_dr_multiplies_friction(act):
|
||||
base = act.compute_motor_force(1.0, 0.0, friction_scale=1.0, damping_scale=0.0)
|
||||
# damping_scale=0 isolates the friction term
|
||||
doubled = act.compute_motor_force(1.0, 0.0, friction_scale=2.0, damping_scale=0.0)
|
||||
assert doubled == pytest.approx(2.0 * base)
|
||||
173
tests/test_runner.py
Normal file
173
tests/test_runner.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Runner integration tests — DR, history, action delay, init randomization.
|
||||
|
||||
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
|
||||
is skipped when JAX isn't installed.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
|
||||
|
||||
DR = {
|
||||
"qpos_noise_std": 0.01,
|
||||
"qvel_noise_std": 0.5,
|
||||
"action_delay_steps": [0, 2],
|
||||
"friction_scale": [0.6, 1.6],
|
||||
"damping_scale": [0.6, 1.6],
|
||||
"torque_scale": [0.85, 1.15],
|
||||
}
|
||||
|
||||
|
||||
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
|
||||
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||
cfg = MuJoCoRunnerConfig(
|
||||
num_envs=num_envs,
|
||||
device="cpu",
|
||||
history_length=history_length,
|
||||
domain_rand=domain_rand or {},
|
||||
)
|
||||
return MuJoCoRunner(env=env, config=cfg)
|
||||
|
||||
|
||||
# ── Observation layout ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_obs_is_raw_plus_history():
|
||||
runner = _runner(num_envs=2, history_length=3)
|
||||
raw_dim = runner.env.observation_space.shape[0] # 6
|
||||
step_dim = raw_dim + 1 # + action
|
||||
assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim
|
||||
|
||||
obs, _ = runner.reset()
|
||||
assert obs.shape == (2, raw_dim + 3 * step_dim)
|
||||
# Fresh history must be zero.
|
||||
assert torch.all(obs[:, raw_dim:] == 0)
|
||||
|
||||
actions = torch.full((2, 1), 0.3)
|
||||
obs, rewards, term, trunc, info = runner.step(actions)
|
||||
assert obs.shape == (2, raw_dim + 3 * step_dim)
|
||||
assert rewards.shape == (2, 1)
|
||||
# Newest history slot holds the commanded action.
|
||||
assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
|
||||
runner.close()
|
||||
|
||||
|
||||
def test_no_history_keeps_plain_obs():
|
||||
runner = _runner(num_envs=2, history_length=0)
|
||||
assert runner.observation_space.shape[0] == 6
|
||||
runner.close()
|
||||
|
||||
|
||||
# ── Domain randomization ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_dr_scales_sampled_within_ranges_and_resampled():
|
||||
runner = _runner(num_envs=16, domain_rand=DR)
|
||||
runner.reset()
|
||||
for name, (lo, hi) in (
|
||||
("friction_scale", (0.6, 1.6)),
|
||||
("damping_scale", (0.6, 1.6)),
|
||||
("torque_scale", (0.85, 1.15)),
|
||||
):
|
||||
vals = runner._dr_scales[name]
|
||||
assert torch.all(vals >= lo) and torch.all(vals <= hi)
|
||||
# 16 independent uniform samples are never all identical.
|
||||
assert vals.std() > 0
|
||||
|
||||
before = runner._dr_scales["friction_scale"].clone()
|
||||
runner.reset()
|
||||
assert not torch.equal(before, runner._dr_scales["friction_scale"])
|
||||
|
||||
delays = runner._dr_delay
|
||||
assert torch.all(delays >= 0) and torch.all(delays <= 2)
|
||||
runner.close()
|
||||
|
||||
|
||||
def test_dr_disabled_is_noop():
|
||||
runner = _runner(num_envs=2, domain_rand={})
|
||||
runner.reset()
|
||||
for vals in runner._dr_scales.values():
|
||||
assert torch.all(vals == 1.0)
|
||||
assert runner._max_delay == 0
|
||||
assert runner._qpos_noise_std == 0.0
|
||||
runner.close()
|
||||
|
||||
|
||||
def test_action_delay_buffer_returns_lagged_action():
|
||||
runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
|
||||
runner.reset()
|
||||
runner._dr_delay = torch.tensor([0, 1, 2])
|
||||
runner._action_buf.zero_()
|
||||
|
||||
a1 = torch.tensor([[1.0], [1.0], [1.0]])
|
||||
a2 = torch.tensor([[2.0], [2.0], [2.0]])
|
||||
a3 = torch.tensor([[3.0], [3.0], [3.0]])
|
||||
|
||||
d1 = runner._apply_action_delay(a1)
|
||||
d2 = runner._apply_action_delay(a2)
|
||||
d3 = runner._apply_action_delay(a3)
|
||||
|
||||
assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
|
||||
assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
|
||||
assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
|
||||
runner.close()
|
||||
|
||||
|
||||
# ── Initial-state randomization ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_wide_pendulum_init_actually_applied():
|
||||
runner = _runner(num_envs=32)
|
||||
qpos, _ = runner._sim_reset(torch.arange(32))
|
||||
pend_angles = qpos[:, 1]
|
||||
# With ±180° init range, samples must spread far beyond the old ±0.05.
|
||||
assert pend_angles.abs().max() > 1.0
|
||||
assert pend_angles.std() > 0.5
|
||||
runner.close()
|
||||
|
||||
|
||||
def test_sim_reset_returns_full_batch():
|
||||
runner = _runner(num_envs=4)
|
||||
runner.reset()
|
||||
qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only
|
||||
assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
|
||||
runner.close()
|
||||
|
||||
|
||||
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
|
||||
|
||||
|
||||
def test_mjx_runner_smoke():
|
||||
pytest.importorskip("jax")
|
||||
pytest.importorskip("mujoco.mjx")
|
||||
from src.runners.mjx import MJXRunner, MJXRunnerConfig
|
||||
|
||||
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||
runner = MJXRunner(
|
||||
env=env,
|
||||
config=MJXRunnerConfig(
|
||||
num_envs=4, device="cpu", history_length=3, domain_rand=DR,
|
||||
),
|
||||
)
|
||||
obs, _ = runner.reset()
|
||||
raw_dim = env.observation_space.shape[0]
|
||||
assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))
|
||||
|
||||
for _ in range(3):
|
||||
actions = torch.rand(4, 1) * 2 - 1
|
||||
obs, rewards, term, trunc, _ = runner.step(actions)
|
||||
assert torch.isfinite(obs).all()
|
||||
assert torch.isfinite(rewards).all()
|
||||
|
||||
# Wide pendulum init must reach MJX resets too.
|
||||
qpos, qvel = runner._sim_reset(torch.arange(4))
|
||||
assert qpos.shape[0] == 4 # full batch
|
||||
runner.close()
|
||||
47
train.py
47
train.py
@@ -1,47 +0,0 @@
|
||||
import hydra
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
from src.core.env import ActuatorConfig
|
||||
|
||||
|
||||
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
return CartPoleConfig(**env_dict)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
env_config = _build_env_config(cfg)
|
||||
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||
|
||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||
# Build ClearML task name dynamically from Hydra config group choices
|
||||
if not training_dict.get("clearml_task"):
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "env")
|
||||
runner_name = choices.get("runner", "runner")
|
||||
training_name = choices.get("training", "algo")
|
||||
training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
|
||||
env = CartPoleEnv(env_config)
|
||||
runner = MuJoCoRunner(env=env, config=runner_config)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
try:
|
||||
trainer.train()
|
||||
finally:
|
||||
trainer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user