Compare commits
9 Commits
feature/ro
...
feature/si
| Author | SHA1 | Date | |
|---|---|---|---|
| 1e0836e1bc | |||
| a98e86ef66 | |||
| 4210b6cb53 | |||
| a6fbde798a | |||
| 56499ebe97 | |||
| b37cd26690 | |||
| 8cc84d6a21 | |||
| 8ed9afe583 | |||
| 5880997786 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -5,6 +5,11 @@
|
|||||||
outputs/
|
outputs/
|
||||||
runs/
|
runs/
|
||||||
smac3_output/
|
smac3_output/
|
||||||
|
training_log.txt
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Real-robot capture data (large .npz recordings)
|
||||||
|
assets/**/recordings/
|
||||||
|
|
||||||
# MuJoCo
|
# MuJoCo
|
||||||
MUJOCO_LOG.TXT
|
MUJOCO_LOG.TXT
|
||||||
|
|||||||
64
README.md
Normal file
64
README.md
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# RL-Framework
|
||||||
|
|
||||||
|
A small, fast RL framework for training sim2real policies on a 3D-printed
|
||||||
|
rotary (Furuta) cartpole — built to scale from a laptop CPU to a GPU worker
|
||||||
|
(ClearML) without code changes, and to grow into more robots and simulators.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Three orthogonal pieces, composed by Hydra config groups:
|
||||||
|
|
||||||
|
| Piece | Role | Implementations |
|
||||||
|
|---|---|---|
|
||||||
|
| **Env** (`src/envs/`) | Task logic: obs / reward / termination / init distribution. Pure torch, batched, backend-agnostic. | `rotary_cartpole` |
|
||||||
|
| **Runner** (`src/runners/`) | Physics + sim2real plumbing (DR, sensor noise, action delay, history buffer). | `mujoco` (CPU), `mjx` (GPU/JAX), `serial` (real ESP32 robot) |
|
||||||
|
| **Trainer** (`src/training/`) | skrl PPO + shared MLP with optional history encoder. | `ppo`, `ppo_mjx`, `ppo_single`, `ppo_real` |
|
||||||
|
|
||||||
|
The robot itself is described once in `assets/<robot>/robot.yaml`
|
||||||
|
(URDF + identified motor model) and shared by **training, sysid and
|
||||||
|
deployment** — the motor model (bias → deadzone → gear compensation,
|
||||||
|
Coulomb + Stribeck friction, viscous damping, first-order lag) is
|
||||||
|
implemented in `src/core/robot.py` and mirrored exactly in the MJX JIT
|
||||||
|
step (`src/runners/mjx.py`).
|
||||||
|
|
||||||
|
## Train
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# CPU (64 parallel MuJoCo envs)
|
||||||
|
python scripts/train.py env=rotary_cartpole runner=mujoco training=ppo
|
||||||
|
|
||||||
|
# GPU (1024 MJX envs) — local
|
||||||
|
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx
|
||||||
|
|
||||||
|
# GPU — remote on ClearML gpu-queue
|
||||||
|
python scripts/train.py env=rotary_cartpole runner=mjx training=ppo_mjx training.remote=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Videos and scalars stream to ClearML. Checkpoints land in `runs/`.
|
||||||
|
|
||||||
|
## Sim2real recipe
|
||||||
|
|
||||||
|
1. **Capture** real trajectories: `python -m src.sysid.capture` (writes `.npz` to `assets/<robot>/recordings/`).
|
||||||
|
2. **Identify** physics: `python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz`
|
||||||
|
— CMA-ES fits inertials/joint dynamics against the recording (motor model is locked from the unified sysid). Writes `sysid_result.json` + `robot_tuned.yaml` + `*_tuned.urdf`.
|
||||||
|
3. **Validate** the fit: `python -m src.sysid.visualize`, then copy `robot_tuned.yaml` → `robot.yaml`.
|
||||||
|
4. **Train with DR + history**: the runner randomizes friction/damping/torque scales, sensor noise and action latency per episode (`configs/runner/mjx.yaml: domain_rand`), and appends a 10-step (obs, action) history to the observation so the policy can implicitly identify the current dynamics (`history_length`).
|
||||||
|
5. **Deploy**: `mjpython scripts/eval.py env=rotary_cartpole runner=serial checkpoint=runs/<run>/checkpoints/agent_X.pt`
|
||||||
|
|
||||||
|
## Other tools
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mjpython scripts/viz.py env=rotary_cartpole # keyboard-drive the sim
|
||||||
|
mjpython scripts/viz.py runner=serial # digital twin of the real robot
|
||||||
|
python scripts/hpo.py env=rotary_cartpole training=ppo_single # ClearML + SMAC3 HPO
|
||||||
|
pytest tests/ # unit tests
|
||||||
|
```
|
||||||
|
|
||||||
|
## Adding a robot / simulator
|
||||||
|
|
||||||
|
- **Robot**: drop `assets/<name>/` (URDF + `robot.yaml`), subclass `BaseEnv`
|
||||||
|
(obs/reward/termination/`initial_state_ranges`), register in `src/core/registry.py`,
|
||||||
|
add `configs/env/<name>.yaml`.
|
||||||
|
- **Simulator**: subclass `BaseRunner` and implement `_sim_initialize`,
|
||||||
|
`_sim_step`, `_sim_reset` (full-batch return) — DR, history and the
|
||||||
|
env-side logic come for free. Register in `scripts/train.py: RUNNER_REGISTRY`.
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?>
|
|
||||||
<robot name="cartpole">
|
|
||||||
|
|
||||||
<!-- World link (fixed base) -->
|
|
||||||
<link name="world"/>
|
|
||||||
|
|
||||||
<!-- Cart (slides along x-axis) -->
|
|
||||||
<link name="cart">
|
|
||||||
<inertial>
|
|
||||||
<mass value="1.0"/>
|
|
||||||
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
|
|
||||||
</inertial>
|
|
||||||
<visual>
|
|
||||||
<geometry>
|
|
||||||
<box size="0.3 0.2 0.1"/>
|
|
||||||
</geometry>
|
|
||||||
</visual>
|
|
||||||
<collision>
|
|
||||||
<geometry>
|
|
||||||
<box size="0.3 0.2 0.1"/>
|
|
||||||
</geometry>
|
|
||||||
</collision>
|
|
||||||
</link>
|
|
||||||
|
|
||||||
<!-- Cart slides along x-axis -->
|
|
||||||
<joint name="cart_joint" type="prismatic">
|
|
||||||
<parent link="world"/>
|
|
||||||
<child link="cart"/>
|
|
||||||
<axis xyz="1 0 0"/>
|
|
||||||
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
|
|
||||||
</joint>
|
|
||||||
|
|
||||||
<!-- Pole (rotates around y-axis, attached on top of cart) -->
|
|
||||||
<link name="pole">
|
|
||||||
<inertial>
|
|
||||||
<origin xyz="0 0 0.3"/>
|
|
||||||
<mass value="0.1"/>
|
|
||||||
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
|
|
||||||
</inertial>
|
|
||||||
<visual>
|
|
||||||
<origin xyz="0 0 0.3"/>
|
|
||||||
<geometry>
|
|
||||||
<cylinder radius="0.02" length="0.6"/>
|
|
||||||
</geometry>
|
|
||||||
</visual>
|
|
||||||
<collision>
|
|
||||||
<origin xyz="0 0 0.3"/>
|
|
||||||
<geometry>
|
|
||||||
<cylinder radius="0.02" length="0.6"/>
|
|
||||||
</geometry>
|
|
||||||
</collision>
|
|
||||||
</link>
|
|
||||||
|
|
||||||
<!-- Pole rotates freely (no motor) -->
|
|
||||||
<joint name="pole_joint" type="revolute">
|
|
||||||
<parent link="cart"/>
|
|
||||||
<child link="pole"/>
|
|
||||||
<origin xyz="0 0 0.05"/>
|
|
||||||
<axis xyz="0 1 0"/>
|
|
||||||
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
|
|
||||||
<dynamics damping="0.0" friction="0.0"/>
|
|
||||||
</joint>
|
|
||||||
|
|
||||||
</robot>
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
# Classic cartpole — robot hardware config.
|
|
||||||
|
|
||||||
urdf: cartpole.urdf
|
|
||||||
|
|
||||||
actuators:
|
|
||||||
- joint: cart_joint
|
|
||||||
type: motor
|
|
||||||
gear: 10.0
|
|
||||||
ctrl_range: [-1.0, 1.0]
|
|
||||||
damping: 0.05
|
|
||||||
10
assets/motor/hardware.yaml
Normal file
10
assets/motor/hardware.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Motor-only hardware config
|
||||||
|
# Encoder and motor constants for the motor-only sysid capture.
|
||||||
|
|
||||||
|
encoder:
|
||||||
|
ppr: 11 # pulses per revolution (before quadrature)
|
||||||
|
gear_ratio: 30.0 # gearbox ratio
|
||||||
|
# counts_per_rev = ppr × gear_ratio × 4 (quadrature) = 1320
|
||||||
|
|
||||||
|
motor:
|
||||||
|
max_pwm: 255 # maximum PWM command accepted by firmware
|
||||||
40
assets/motor/motor.xml
Normal file
40
assets/motor/motor.xml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<mujoco model="motor_sysid">
|
||||||
|
<compiler angle="radian" autolimits="true"/>
|
||||||
|
<option timestep="0.002" integrator="Euler"/>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<light pos="0 0 1" dir="0 0 -1"/>
|
||||||
|
|
||||||
|
<!-- Fixed base (gearbox housing) -->
|
||||||
|
<body name="base" pos="0 0 0.15">
|
||||||
|
<geom type="cylinder" size="0.04 0.06" mass="0.921"
|
||||||
|
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
|
||||||
|
|
||||||
|
<!-- Arm: rotates around motor_joint (z-axis) -->
|
||||||
|
<body name="arm" pos="0 0 0">
|
||||||
|
<joint name="motor_joint" type="hinge" axis="0 0 1"
|
||||||
|
range="-1.5708 1.5708"
|
||||||
|
damping="0.001" armature="0.0001" frictionloss="0.03"/>
|
||||||
|
|
||||||
|
<!-- Rotor disk: mass is tunable by the optimizer -->
|
||||||
|
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
|
||||||
|
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
|
||||||
|
contype="0" conaffinity="0"/>
|
||||||
|
|
||||||
|
<!-- Arm load: lightweight arm attached to motor shaft -->
|
||||||
|
<geom name="arm_load" type="capsule" size="0.004"
|
||||||
|
fromto="0 0 0 -0.014 0.002 0.016"
|
||||||
|
mass="0.021" rgba="0.8 0.3 0.1 1"
|
||||||
|
contype="0" conaffinity="0"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<general name="motor" joint="motor_joint"
|
||||||
|
gear="0.064"
|
||||||
|
ctrllimited="true" ctrlrange="-1 1"
|
||||||
|
dyntype="filter" dynprm="0.03"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
42
assets/motor/motor_bare.xml
Normal file
42
assets/motor/motor_bare.xml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<!-- Motor-only model WITHOUT arm/pendulum load.
|
||||||
|
Use for arm-off sysid captures. arm_load mass set to ~0. -->
|
||||||
|
<mujoco model="motor_sysid_bare">
|
||||||
|
<compiler angle="radian" autolimits="true"/>
|
||||||
|
<option timestep="0.002" integrator="Euler"/>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<light pos="0 0 1" dir="0 0 -1"/>
|
||||||
|
|
||||||
|
<!-- Fixed base (gearbox housing) -->
|
||||||
|
<body name="base" pos="0 0 0.15">
|
||||||
|
<geom type="cylinder" size="0.04 0.06" mass="0.921"
|
||||||
|
rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0"/>
|
||||||
|
|
||||||
|
<!-- Arm: rotates around motor_joint (z-axis) -->
|
||||||
|
<body name="arm" pos="0 0 0">
|
||||||
|
<joint name="motor_joint" type="hinge" axis="0 0 1"
|
||||||
|
range="-1.5708 1.5708"
|
||||||
|
damping="0.001" armature="0.0001" frictionloss="0.03"/>
|
||||||
|
|
||||||
|
<!-- Rotor disk: mass is tunable by the optimizer -->
|
||||||
|
<geom name="rotor_disk" type="cylinder" size="0.008 0.004"
|
||||||
|
mass="0.012" pos="0 0 0" rgba="0.6 0.6 0.6 1"
|
||||||
|
contype="0" conaffinity="0"/>
|
||||||
|
|
||||||
|
<!-- Arm load: near-zero mass for bare-shaft sysid -->
|
||||||
|
<geom name="arm_load" type="capsule" size="0.004"
|
||||||
|
fromto="0 0 0 -0.014 0.002 0.016"
|
||||||
|
mass="0.001" rgba="0.8 0.3 0.1 0.3"
|
||||||
|
contype="0" conaffinity="0"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<general name="motor" joint="motor_joint"
|
||||||
|
gear="0.064"
|
||||||
|
ctrllimited="true" ctrlrange="-1 1"
|
||||||
|
dyntype="filter" dynprm="0.03"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
BIN
assets/motor/motor_sysid_comparison.png
Normal file
BIN
assets/motor/motor_sysid_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 463 KiB |
67
assets/motor/motor_sysid_result.json
Normal file
67
assets/motor/motor_sysid_result.json
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
{
|
||||||
|
"best_params": {
|
||||||
|
"actuator_gear_pos": 0.3711939014035462,
|
||||||
|
"actuator_gear_neg": 0.42814281188601877,
|
||||||
|
"actuator_filter_tau": 0.022300731564787457,
|
||||||
|
"motor_damping_pos": 0.0013836218905629106,
|
||||||
|
"motor_damping_neg": 0.005196351489379768,
|
||||||
|
"motor_armature": 0.0027534181478216656,
|
||||||
|
"motor_frictionloss_pos": 0.03674406439012955,
|
||||||
|
"motor_frictionloss_neg": 0.06908200024786905,
|
||||||
|
"viscous_quadratic": 0.000958226218765762,
|
||||||
|
"back_emf_gain": 0.0036492272912788297,
|
||||||
|
"stribeck_friction_boost": 0.044748043677129666,
|
||||||
|
"stribeck_vel": 4.0513395945623705,
|
||||||
|
"rotor_mass": 0.03982640764507874,
|
||||||
|
"motor_deadzone_pos": 0.14181963932762467,
|
||||||
|
"motor_deadzone_neg": 0.031454276545010214,
|
||||||
|
"action_bias": -0.007362969452509152,
|
||||||
|
"gearbox_backlash": 1.4749880999407965e-09
|
||||||
|
},
|
||||||
|
"best_cost": 0.21167928018839952,
|
||||||
|
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/motor/recordings/motor_from_cartpole_162432.npz",
|
||||||
|
"param_names": [
|
||||||
|
"actuator_gear_pos",
|
||||||
|
"actuator_gear_neg",
|
||||||
|
"actuator_filter_tau",
|
||||||
|
"motor_damping_pos",
|
||||||
|
"motor_damping_neg",
|
||||||
|
"motor_armature",
|
||||||
|
"motor_frictionloss_pos",
|
||||||
|
"motor_frictionloss_neg",
|
||||||
|
"viscous_quadratic",
|
||||||
|
"back_emf_gain",
|
||||||
|
"stribeck_friction_boost",
|
||||||
|
"stribeck_vel",
|
||||||
|
"rotor_mass",
|
||||||
|
"motor_deadzone_pos",
|
||||||
|
"motor_deadzone_neg",
|
||||||
|
"action_bias",
|
||||||
|
"gearbox_backlash"
|
||||||
|
],
|
||||||
|
"defaults": {
|
||||||
|
"actuator_gear_pos": 0.064,
|
||||||
|
"actuator_gear_neg": 0.064,
|
||||||
|
"actuator_filter_tau": 0.03,
|
||||||
|
"motor_damping_pos": 0.003,
|
||||||
|
"motor_damping_neg": 0.003,
|
||||||
|
"motor_armature": 0.0001,
|
||||||
|
"motor_frictionloss_pos": 0.03,
|
||||||
|
"motor_frictionloss_neg": 0.03,
|
||||||
|
"viscous_quadratic": 0.0,
|
||||||
|
"back_emf_gain": 0.0,
|
||||||
|
"stribeck_friction_boost": 0.0,
|
||||||
|
"stribeck_vel": 2.0,
|
||||||
|
"rotor_mass": 0.012,
|
||||||
|
"motor_deadzone_pos": 0.08,
|
||||||
|
"motor_deadzone_neg": 0.08,
|
||||||
|
"action_bias": 0.0,
|
||||||
|
"gearbox_backlash": 0.0
|
||||||
|
},
|
||||||
|
"timestamp": "2026-03-23T20:50:53.648753",
|
||||||
|
"history_summary": {
|
||||||
|
"first_cost": 5.105499579419285,
|
||||||
|
"final_cost": 0.21167928018839952,
|
||||||
|
"generations": 500
|
||||||
|
}
|
||||||
|
}
|
||||||
19
assets/motor/motor_tuned.xml
Normal file
19
assets/motor/motor_tuned.xml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version='1.0' encoding='utf-8'?>
|
||||||
|
<mujoco model="motor_sysid">
|
||||||
|
<compiler angle="radian" autolimits="true" />
|
||||||
|
<option timestep="0.002" integrator="Euler" />
|
||||||
|
<worldbody>
|
||||||
|
<light pos="0 0 1" dir="0 0 -1" />
|
||||||
|
<body name="base" pos="0 0 0.15">
|
||||||
|
<geom type="cylinder" size="0.04 0.06" mass="0.921" rgba="0.3 0.3 0.3 1" contype="0" conaffinity="0" />
|
||||||
|
<body name="arm" pos="0 0 0">
|
||||||
|
<joint name="motor_joint" type="hinge" axis="0 0 1" range="-1.5708 1.5708" damping="0" armature="0.0027534181478216656" frictionloss="0" />
|
||||||
|
<geom name="rotor_disk" type="cylinder" size="0.008 0.004" mass="0.03982640764507874" pos="0 0 0" rgba="0.6 0.6 0.6 1" contype="0" conaffinity="0" />
|
||||||
|
<geom name="arm_load" type="capsule" size="0.004" fromto="0 0 0 -0.014 0.002 0.016" mass="0.021" rgba="0.8 0.3 0.1 1" contype="0" conaffinity="0" />
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
<actuator>
|
||||||
|
<general name="motor" joint="motor_joint" gear="0.3996683566447825" ctrllimited="true" ctrlrange="-1 1" dyntype="filter" dynprm="0.022300731564787457" />
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
6
assets/motor/robot.yaml
Normal file
6
assets/motor/robot.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Motor-only sysid config
|
||||||
|
# Minimal config for the motor-only identification pipeline.
|
||||||
|
# The optimizer patches motor.xml in-memory; this file tells it
|
||||||
|
# which MJCF to load and provides encoder/hardware constants.
|
||||||
|
|
||||||
|
mjcf: motor.xml
|
||||||
4
assets/motor/robot_bare.yaml
Normal file
4
assets/motor/robot_bare.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
# Motor-only sysid config — bare shaft (no arm/pendulum load)
|
||||||
|
# Use with arm-off captures to identify pure motor dynamics.
|
||||||
|
|
||||||
|
mjcf: motor_bare.xml
|
||||||
23
assets/motor/robot_tuned.yaml
Normal file
23
assets/motor/robot_tuned.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Tuned motor config — generated by src.sysid.motor.optimize
|
||||||
|
# Original: robot.yaml
|
||||||
|
|
||||||
|
mjcf: motor_tuned.xml
|
||||||
|
joints:
|
||||||
|
motor_joint:
|
||||||
|
armature: 0.002753
|
||||||
|
frictionloss: 0.052913
|
||||||
|
hardware_realism:
|
||||||
|
actuator_gear_pos: 0.371194
|
||||||
|
actuator_gear_neg: 0.428143
|
||||||
|
motor_damping_pos: 0.001384
|
||||||
|
motor_damping_neg: 0.005196
|
||||||
|
motor_frictionloss_pos: 0.036744
|
||||||
|
motor_frictionloss_neg: 0.069082
|
||||||
|
motor_deadzone_pos: 0.14182
|
||||||
|
motor_deadzone_neg: 0.031454
|
||||||
|
action_bias: -0.007363
|
||||||
|
viscous_quadratic: 0.000958
|
||||||
|
back_emf_gain: 0.003649
|
||||||
|
stribeck_friction_boost: 0.044748
|
||||||
|
stribeck_vel: 4.05134
|
||||||
|
gearbox_backlash: 0.0
|
||||||
@@ -1,21 +1,30 @@
|
|||||||
# Tuned robot config — generated by src.sysid.optimize
|
# Canonical training model — unified sysid (cost 0.925, 475 generations).
|
||||||
|
# Source: sysid_result.json → exported via src.sysid.export.
|
||||||
|
# Key physics: ~96 ms motor lag (filter_tau), Stribeck friction, driver bias.
|
||||||
|
# Regenerate with:
|
||||||
|
# python -m src.sysid.optimize --robot-path assets/rotary_cartpole --recording <capture>.npz
|
||||||
|
# then copy robot_tuned.yaml over this file once validated
|
||||||
|
# (python -m src.sysid.visualize to compare real vs sim).
|
||||||
|
|
||||||
|
urdf: rotary_cartpole_tuned.urdf
|
||||||
|
|
||||||
urdf: rotary_cartpole.urdf
|
|
||||||
actuators:
|
actuators:
|
||||||
- joint: motor_joint
|
- joint: motor_joint
|
||||||
type: motor
|
type: motor
|
||||||
gear: [0.424182, 0.425031] # torque constant [pos, neg] (motor sysid)
|
gear: [0.846499, 1.183733] # torque constant [pos, neg]
|
||||||
ctrl_range: [-0.592, 0.592] # effective control bound (sysid-tuned)
|
ctrl_range: [-0.686251, 0.686251] # PWM saturation (MAX_MOTOR_SPEED / 255)
|
||||||
deadzone: [0.141291, 0.078015] # L298N min |ctrl| for torque [pos, neg]
|
deadzone: [0.181097, 0.202072] # L298N min |ctrl| for torque [pos, neg]
|
||||||
damping: [0.002027, 0.014665] # viscous damping [pos, neg]
|
damping: [0.013165, 0.015452] # viscous damping [pos, neg]
|
||||||
frictionloss: [0.057328, 0.053355] # Coulomb friction [pos, neg]
|
frictionloss: [0.014244, 0.001005] # Coulomb friction [pos, neg]
|
||||||
filter_tau: 0.005035 # 1st-order actuator filter (motor sysid)
|
filter_tau: 0.096263 # 1st-order actuator lag (s) — dominant!
|
||||||
viscous_quadratic: 0.000285 # velocity² drag
|
stribeck_friction_boost: 0.068594 # extra static friction near standstill
|
||||||
back_emf_gain: 0.006758 # back-EMF torque reduction
|
stribeck_vel: 5.279594 # Stribeck decay velocity (rad/s)
|
||||||
|
action_bias: 0.056566 # additive ctrl bias (driver asymmetry)
|
||||||
|
|
||||||
joints:
|
joints:
|
||||||
motor_joint:
|
motor_joint:
|
||||||
armature: 0.002773 # reflected rotor inertia (motor sysid)
|
armature: 0.001676 # reflected rotor inertia (kg·m²)
|
||||||
frictionloss: 0.0 # handled by motor model via qfrc_applied
|
frictionloss: 0.0 # handled by motor model via qfrc_applied
|
||||||
pendulum_joint:
|
pendulum_joint:
|
||||||
damping: 0.000119
|
damping: 1.2e-05
|
||||||
frictionloss: 1.0e-05
|
frictionloss: 7.2e-05
|
||||||
|
|||||||
34
assets/rotary_cartpole/robot_tuned.yaml
Normal file
34
assets/rotary_cartpole/robot_tuned.yaml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# Tuned robot config — generated by src.sysid.optimize
|
||||||
|
# Original: robot.yaml
|
||||||
|
# Run `python -m src.sysid.visualize` to compare real vs sim.
|
||||||
|
|
||||||
|
urdf: rotary_cartpole_tuned.urdf
|
||||||
|
actuators:
|
||||||
|
- joint: motor_joint
|
||||||
|
type: motor
|
||||||
|
gear:
|
||||||
|
- 0.846499
|
||||||
|
- 1.183733
|
||||||
|
ctrl_range:
|
||||||
|
- -0.686251
|
||||||
|
- 0.686251
|
||||||
|
deadzone:
|
||||||
|
- 0.181097
|
||||||
|
- 0.202072
|
||||||
|
damping:
|
||||||
|
- 0.013165
|
||||||
|
- 0.015452
|
||||||
|
frictionloss:
|
||||||
|
- 0.014244
|
||||||
|
- 0.001005
|
||||||
|
filter_tau: 0.096263
|
||||||
|
stribeck_friction_boost: 0.068594
|
||||||
|
stribeck_vel: 5.279594
|
||||||
|
action_bias: 0.056566
|
||||||
|
joints:
|
||||||
|
motor_joint:
|
||||||
|
armature: 0.001676
|
||||||
|
frictionloss: 0.0
|
||||||
|
pendulum_joint:
|
||||||
|
damping: 1.2e-05
|
||||||
|
frictionloss: 7.2e-05
|
||||||
80
assets/rotary_cartpole/rotary_cartpole_tuned.urdf
Normal file
80
assets/rotary_cartpole/rotary_cartpole_tuned.urdf
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
<?xml version='1.0' encoding='utf-8'?>
|
||||||
|
<robot name="rotary_cartpole">
|
||||||
|
<link name="world" />
|
||||||
|
<link name="base_link">
|
||||||
|
<inertial>
|
||||||
|
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0" />
|
||||||
|
<mass value="0.921" />
|
||||||
|
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559" ixy="0.0" iyz="-0.000149" ixz="6e-06" />
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
<joint name="base_joint" type="fixed">
|
||||||
|
<parent link="world" />
|
||||||
|
<child link="base_link" />
|
||||||
|
</joint>
|
||||||
|
<link name="arm">
|
||||||
|
<inertial>
|
||||||
|
<origin xyz="0.02679980831009001 -0.015110803875962989 -0.005337417994989926" rpy="0 0 0" />
|
||||||
|
<mass value="0.04052645463607292" />
|
||||||
|
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06" ixy="0.0" iyz="7.20e-08" ixz="0.0" />
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
<joint name="motor_joint" type="revolute">
|
||||||
|
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0" />
|
||||||
|
<parent link="base_link" />
|
||||||
|
<child link="arm" />
|
||||||
|
<axis xyz="0 0 1" />
|
||||||
|
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0" />
|
||||||
|
<dynamics damping="0.001" />
|
||||||
|
</joint>
|
||||||
|
<link name="pendulum">
|
||||||
|
<inertial>
|
||||||
|
<origin xyz="0.04237886229290564 -0.05762212306183831 -0.0006039324398328591" rpy="0 0 0" />
|
||||||
|
<mass value="0.056793277126969105" />
|
||||||
|
<inertia ixx="0.0005956997356690322" iyy="8.986080748758447e-05" izz="0.0006855370063143978" ixy="-1.074983574165622e-05" iyz="0.0" ixz="0.0" />
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0" />
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
<joint name="pendulum_joint" type="continuous">
|
||||||
|
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0" />
|
||||||
|
<parent link="arm" />
|
||||||
|
<child link="pendulum" />
|
||||||
|
<axis xyz="0 -1 0" />
|
||||||
|
<dynamics damping="0.0001" />
|
||||||
|
</joint>
|
||||||
|
</robot>
|
||||||
BIN
assets/rotary_cartpole/sysid_comparison.png
Normal file
BIN
assets/rotary_cartpole/sysid_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 845 KiB |
101
assets/rotary_cartpole/sysid_result.json
Normal file
101
assets/rotary_cartpole/sysid_result.json
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
{
|
||||||
|
"best_params": {
|
||||||
|
"actuator_gear_pos": 0.8464986357922419,
|
||||||
|
"actuator_gear_neg": 1.1837332515116656,
|
||||||
|
"actuator_filter_tau": 0.09626340711982824,
|
||||||
|
"motor_damping_pos": 0.013165358841570964,
|
||||||
|
"motor_damping_neg": 0.015451769431035585,
|
||||||
|
"motor_armature": 0.0016762295220909746,
|
||||||
|
"motor_frictionloss_pos": 0.014244285616762612,
|
||||||
|
"motor_frictionloss_neg": 0.0010053574956085138,
|
||||||
|
"stribeck_friction_boost": 0.06859381652654695,
|
||||||
|
"stribeck_vel": 5.279593819777324,
|
||||||
|
"motor_deadzone_pos": 0.1810971675049997,
|
||||||
|
"motor_deadzone_neg": 0.20207167371643614,
|
||||||
|
"action_bias": 0.05656632270757292,
|
||||||
|
"arm_mass": 0.04052645463607292,
|
||||||
|
"arm_com_x": 0.02679980831009001,
|
||||||
|
"arm_com_y": -0.015110803875962989,
|
||||||
|
"arm_com_z": -0.005337417994989926,
|
||||||
|
"pendulum_mass": 0.056793277126969105,
|
||||||
|
"pendulum_com_x": 0.04237886229290564,
|
||||||
|
"pendulum_com_y": -0.05762212306183831,
|
||||||
|
"pendulum_com_z": -0.0006039324398328591,
|
||||||
|
"pendulum_ixx": 0.0005956997356690322,
|
||||||
|
"pendulum_iyy": 8.986080748758447e-05,
|
||||||
|
"pendulum_izz": 0.0006855370063143978,
|
||||||
|
"pendulum_ixy": -1.074983574165622e-05,
|
||||||
|
"pendulum_damping": 1.1718327130851149e-05,
|
||||||
|
"pendulum_frictionloss": 7.204244487092445e-05,
|
||||||
|
"ctrl_limit": 0.6862506602999546
|
||||||
|
},
|
||||||
|
"best_cost": 0.9250391788859942,
|
||||||
|
"recording": "/Users/victormylle/Library/CloudStorage/SeaDrive-VictorMylle(cloud.optimize-it.be)/My Libraries/Projects/AI/RL-Framework/assets/rotary_cartpole/recordings/capture_20260328_153749.npz",
|
||||||
|
"param_names": [
|
||||||
|
"actuator_gear_pos",
|
||||||
|
"actuator_gear_neg",
|
||||||
|
"actuator_filter_tau",
|
||||||
|
"motor_damping_pos",
|
||||||
|
"motor_damping_neg",
|
||||||
|
"motor_armature",
|
||||||
|
"motor_frictionloss_pos",
|
||||||
|
"motor_frictionloss_neg",
|
||||||
|
"stribeck_friction_boost",
|
||||||
|
"stribeck_vel",
|
||||||
|
"motor_deadzone_pos",
|
||||||
|
"motor_deadzone_neg",
|
||||||
|
"action_bias",
|
||||||
|
"arm_mass",
|
||||||
|
"arm_com_x",
|
||||||
|
"arm_com_y",
|
||||||
|
"arm_com_z",
|
||||||
|
"pendulum_mass",
|
||||||
|
"pendulum_com_x",
|
||||||
|
"pendulum_com_y",
|
||||||
|
"pendulum_com_z",
|
||||||
|
"pendulum_ixx",
|
||||||
|
"pendulum_iyy",
|
||||||
|
"pendulum_izz",
|
||||||
|
"pendulum_ixy",
|
||||||
|
"pendulum_damping",
|
||||||
|
"pendulum_frictionloss",
|
||||||
|
"ctrl_limit"
|
||||||
|
],
|
||||||
|
"defaults": {
|
||||||
|
"actuator_gear_pos": 0.371194,
|
||||||
|
"actuator_gear_neg": 0.428143,
|
||||||
|
"actuator_filter_tau": 0.022301,
|
||||||
|
"motor_damping_pos": 0.001384,
|
||||||
|
"motor_damping_neg": 0.005196,
|
||||||
|
"motor_armature": 0.002753,
|
||||||
|
"motor_frictionloss_pos": 0.036744,
|
||||||
|
"motor_frictionloss_neg": 0.069082,
|
||||||
|
"stribeck_friction_boost": 0.0,
|
||||||
|
"stribeck_vel": 2.0,
|
||||||
|
"motor_deadzone_pos": 0.14182,
|
||||||
|
"motor_deadzone_neg": 0.031454,
|
||||||
|
"action_bias": 0.0,
|
||||||
|
"arm_mass": 0.0211,
|
||||||
|
"arm_com_x": -0.0071,
|
||||||
|
"arm_com_y": 0.00085,
|
||||||
|
"arm_com_z": 0.00795,
|
||||||
|
"pendulum_mass": 0.03937,
|
||||||
|
"pendulum_com_x": 0.06025,
|
||||||
|
"pendulum_com_y": -0.07602,
|
||||||
|
"pendulum_com_z": -0.00346,
|
||||||
|
"pendulum_ixx": 6.2e-05,
|
||||||
|
"pendulum_iyy": 3.7e-05,
|
||||||
|
"pendulum_izz": 7.83e-05,
|
||||||
|
"pendulum_ixy": -6.93e-06,
|
||||||
|
"pendulum_damping": 0.0001,
|
||||||
|
"pendulum_frictionloss": 0.0001,
|
||||||
|
"ctrl_limit": 0.588
|
||||||
|
},
|
||||||
|
"preprocess_vel": true,
|
||||||
|
"timestamp": "2026-03-28T17:09:39.241413",
|
||||||
|
"history_summary": {
|
||||||
|
"first_cost": 13.490216059926947,
|
||||||
|
"final_cost": 0.9250391788859942,
|
||||||
|
"generations": 475
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- env: cartpole
|
- env: rotary_cartpole
|
||||||
- runner: mujoco
|
- runner: mujoco
|
||||||
- training: ppo
|
- training: ppo
|
||||||
- _self_
|
- _self_
|
||||||
7
configs/env/cartpole.yaml
vendored
7
configs/env/cartpole.yaml
vendored
@@ -1,7 +0,0 @@
|
|||||||
max_steps: 500
|
|
||||||
robot_path: assets/cartpole
|
|
||||||
angle_threshold: 0.418
|
|
||||||
cart_limit: 2.4
|
|
||||||
reward_alive: 1.0
|
|
||||||
reward_pole_upright_scale: 1.0
|
|
||||||
reward_action_penalty_scale: 0.01
|
|
||||||
9
configs/env/rotary_cartpole.yaml
vendored
9
configs/env/rotary_cartpole.yaml
vendored
@@ -1,11 +1,18 @@
|
|||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
robot_path: assets/rotary_cartpole
|
robot_path: assets/rotary_cartpole
|
||||||
reward_upright_scale: 1.0
|
reward_upright_scale: 1.0
|
||||||
|
alive_bonus: 0.25 # per-step survival bonus (living must beat dying)
|
||||||
|
balance_bonus: 2.0 # extra reward for upright AND still (beats spinning)
|
||||||
|
balance_vel_scale: 0.5 # how fast the balance bonus decays with pendulum speed
|
||||||
|
|
||||||
# ── Regularisation penalties (prevent fast spinning) ─────────────────
|
# ── Regularisation penalties (prevent fast spinning) ─────────────────
|
||||||
motor_vel_penalty: 0.01 # penalise high motor angular velocity
|
motor_vel_penalty: 0.01 # penalise high motor angular velocity
|
||||||
motor_angle_penalty: 0.05 # penalise deviation from centre
|
motor_angle_penalty: 0.05 # penalise deviation from centre
|
||||||
action_penalty: 0.05 # penalise large actions (energy cost)
|
action_penalty: 0.05 # penalise large actions (energy cost)
|
||||||
|
action_rate_penalty: 0.01 # penalise action changes (real-motor smoothness)
|
||||||
|
|
||||||
|
# ── Initial state randomisation ──────────────────────────────────────
|
||||||
|
pendulum_init_range_deg: 180.0 # pendulum starts in [-180°, +180°]
|
||||||
|
|
||||||
# ── Software safety limit (env-level, always applied) ────────────────
|
# ── Software safety limit (env-level, always applied) ────────────────
|
||||||
motor_angle_limit_deg: 90.0 # terminate episode if motor exceeds ±90°
|
motor_angle_limit_deg: 90.0 # terminate episode if motor exceeds ±90°
|
||||||
@@ -16,4 +23,6 @@ hpo:
|
|||||||
motor_vel_penalty: {min: 0.001, max: 0.1}
|
motor_vel_penalty: {min: 0.001, max: 0.1}
|
||||||
motor_angle_penalty: {min: 0.01, max: 0.2}
|
motor_angle_penalty: {min: 0.01, max: 0.2}
|
||||||
action_penalty: {min: 0.01, max: 0.2}
|
action_penalty: {min: 0.01, max: 0.2}
|
||||||
|
action_rate_penalty: {min: 0.001, max: 0.1}
|
||||||
|
pendulum_init_range_deg: {min: 30.0, max: 180.0}
|
||||||
max_steps: {values: [500, 1000, 2000]}
|
max_steps: {values: [500, 1000, 2000]}
|
||||||
@@ -2,3 +2,15 @@ num_envs: 1024 # MJX shines with many parallel envs
|
|||||||
device: auto # auto = cuda if available, else cpu
|
device: auto # auto = cuda if available, else cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 10
|
substeps: 10
|
||||||
|
history_length: 10 # (obs, action) window for implicit adaptation
|
||||||
|
|
||||||
|
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||||
|
# Full DR on GPU: latency + sensor noise + per-env dynamics scales
|
||||||
|
# (friction/damping/torque) are all applied inside the JIT step.
|
||||||
|
domain_rand:
|
||||||
|
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||||
|
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||||
|
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||||
|
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
|
||||||
|
damping_scale: [0.6, 1.6] # viscous-damping multiplier
|
||||||
|
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation
|
||||||
|
|||||||
@@ -2,3 +2,15 @@ num_envs: 64
|
|||||||
device: auto # auto = cuda if available, else cpu
|
device: auto # auto = cuda if available, else cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 10
|
substeps: 10
|
||||||
|
history_length: 10 # (obs, action) window for implicit adaptation
|
||||||
|
|
||||||
|
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||||
|
# Noise/delay levels anchored to the real recordings (~50 Hz, ~0.5 rad/s
|
||||||
|
# velocity noise, ≤1-step latency). Set domain_rand: {} to disable.
|
||||||
|
domain_rand:
|
||||||
|
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||||
|
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||||
|
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||||
|
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier
|
||||||
|
damping_scale: [0.6, 1.6] # viscous-damping multiplier
|
||||||
|
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation
|
||||||
|
|||||||
@@ -5,3 +5,11 @@ num_envs: 1
|
|||||||
device: cpu
|
device: cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 10
|
substeps: 10
|
||||||
|
history_length: 10
|
||||||
|
|
||||||
|
# Clean by default (deterministic eval). Confirming-experiment example —
|
||||||
|
# re-eval an existing checkpoint in sim with a fixed 1-step action delay:
|
||||||
|
# mjpython scripts/eval.py env=rotary_cartpole runner=mujoco_single \
|
||||||
|
# checkpoint=runs/.../agent_XXXX.pt \
|
||||||
|
# '++runner.domain_rand.action_delay_steps=[1,1]'
|
||||||
|
domain_rand: {}
|
||||||
|
|||||||
@@ -8,3 +8,4 @@ port: /dev/cu.usbserial-0001
|
|||||||
baud: 115200
|
baud: 115200
|
||||||
dt: 0.02 # control loop period (50 Hz, matches training)
|
dt: 0.02 # control loop period (50 Hz, matches training)
|
||||||
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
|
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
|
||||||
|
history_length: 10 # must match training runner
|
||||||
|
|||||||
@@ -1,23 +1,31 @@
|
|||||||
hidden_sizes: [128, 128]
|
# PPO defaults — sized for the CPU MuJoCo runner (64 parallel envs).
|
||||||
total_timesteps: 5000000
|
# 128 rollout steps × 64 envs ≈ 8K samples per update.
|
||||||
rollout_steps: 1024
|
|
||||||
learning_epochs: 4
|
hidden_sizes: [256, 256]
|
||||||
|
total_timesteps: 500000 # × 64 envs = 32M env steps
|
||||||
|
rollout_steps: 128
|
||||||
|
learning_epochs: 5
|
||||||
mini_batches: 4
|
mini_batches: 4
|
||||||
discount_factor: 0.99
|
discount_factor: 0.99
|
||||||
gae_lambda: 0.95
|
gae_lambda: 0.95
|
||||||
learning_rate: 0.0003
|
learning_rate: 0.0003
|
||||||
clip_ratio: 0.2
|
clip_ratio: 0.2
|
||||||
value_loss_scale: 0.5
|
value_loss_scale: 0.5
|
||||||
entropy_loss_scale: 0.05
|
entropy_loss_scale: 0.01
|
||||||
|
kl_threshold: 0.01 # KL-adaptive LR; 0 = fixed learning rate
|
||||||
log_interval: 1000
|
log_interval: 1000
|
||||||
checkpoint_interval: 50000
|
checkpoint_interval: 50000
|
||||||
|
|
||||||
initial_log_std: 0.5
|
initial_log_std: -0.5
|
||||||
min_log_std: -2.0
|
min_log_std: -4.0
|
||||||
max_log_std: 2.0
|
max_log_std: 2.0
|
||||||
|
|
||||||
record_video_every: 10000
|
record_video_every: 10000
|
||||||
|
|
||||||
|
# History encoder output dim — the window size itself comes from
|
||||||
|
# runner.history_length (single source of truth).
|
||||||
|
embedding_dim: 32
|
||||||
|
|
||||||
# ClearML remote execution (GPU worker)
|
# ClearML remote execution (GPU worker)
|
||||||
remote: false
|
remote: false
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
# PPO tuned for MJX (1024+ parallel envs on GPU).
|
# PPO sized for MJX (1024+ parallel envs on GPU).
|
||||||
# Inherits defaults + HPO ranges from ppo.yaml.
|
# Inherits defaults + HPO ranges from ppo.yaml.
|
||||||
# With 1024 envs, each timestep collects 1024 samples, so total_timesteps
|
#
|
||||||
# can be much lower than the CPU config.
|
# Short rollouts × many envs is the GPU-PPO sweet spot:
|
||||||
|
# 24 steps × 1024 envs ≈ 25K samples per update (~6K per mini-batch).
|
||||||
|
# (The old rollout_steps=2048 inherited from the CPU config meant a
|
||||||
|
# 2M-sample memory per update — GBs of VRAM and glacial updates.)
|
||||||
|
|
||||||
defaults:
|
defaults:
|
||||||
- ppo
|
- ppo
|
||||||
- _self_
|
- _self_
|
||||||
|
|
||||||
total_timesteps: 300000 # 300K × 1024 envs ≈ 307M env steps
|
rollout_steps: 24
|
||||||
mini_batches: 32 # keep mini-batch size similar (~32K)
|
mini_batches: 4
|
||||||
learning_rate: 0.001 # ~3x higher LR for 16x larger batch (sqrt scaling)
|
learning_epochs: 5
|
||||||
|
learning_rate: 0.0003 # KL-adaptive scheduler handles the rest
|
||||||
|
total_timesteps: 100000 # × 1024 envs ≈ 100M env steps
|
||||||
log_interval: 100
|
log_interval: 100
|
||||||
checkpoint_interval: 10000
|
checkpoint_interval: 10000
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
torch
|
torch
|
||||||
gymnasium
|
gymnasium==1.2.3
|
||||||
hydra-core
|
hydra-core
|
||||||
omegaconf
|
omegaconf
|
||||||
mujoco
|
mujoco==3.5.0
|
||||||
mujoco-mjx
|
mujoco-mjx==3.5.0
|
||||||
jax
|
jax[cuda12]==0.9.1 ; sys_platform == "linux"
|
||||||
skrl[torch]
|
jax==0.9.1 ; sys_platform != "linux"
|
||||||
|
skrl[torch]==1.4.3
|
||||||
clearml
|
clearml
|
||||||
imageio
|
imageio
|
||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
|
|||||||
@@ -74,14 +74,31 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
|
|||||||
return tuple(sizes)
|
return tuple(sizes)
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
|
||||||
|
"""Return the history encoder output dim, if present.
|
||||||
|
|
||||||
|
Lets eval reconstruct an embedding policy without knowing the training
|
||||||
|
embedding_dim — read it straight from the saved weights.
|
||||||
|
"""
|
||||||
|
if "history_encoder.fc.weight" in state_dict:
|
||||||
|
return state_dict["history_encoder.fc.weight"].shape[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def load_policy(
|
def load_policy(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
observation_space: spaces.Space,
|
observation_space: spaces.Space,
|
||||||
action_space: spaces.Space,
|
action_space: spaces.Space,
|
||||||
device: torch.device = torch.device("cpu"),
|
device: torch.device = torch.device("cpu"),
|
||||||
|
history_length: int = 0,
|
||||||
|
raw_obs_dim: int = 0,
|
||||||
) -> tuple[SharedMLP, RunningStandardScaler]:
|
) -> tuple[SharedMLP, RunningStandardScaler]:
|
||||||
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
|
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
|
||||||
|
|
||||||
|
For DR + history-embedding policies (history_length > 0), the history
|
||||||
|
encoder is reconstructed too — its output dim is read back from the
|
||||||
|
saved weights.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(model, state_preprocessor) ready for inference.
|
(model, state_preprocessor) ready for inference.
|
||||||
"""
|
"""
|
||||||
@@ -89,13 +106,18 @@ def load_policy(
|
|||||||
|
|
||||||
# Infer architecture from saved weights.
|
# Infer architecture from saved weights.
|
||||||
hidden_sizes = _infer_hidden_sizes(ckpt["policy"])
|
hidden_sizes = _infer_hidden_sizes(ckpt["policy"])
|
||||||
|
enc_out = _infer_encoder_out_dim(ckpt["policy"])
|
||||||
|
|
||||||
# Reconstruct model.
|
# Reconstruct model — pass through the encoder config so a DR+embedding
|
||||||
|
# checkpoint rebuilds the history encoder with matching dimensions.
|
||||||
model = SharedMLP(
|
model = SharedMLP(
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
device=device,
|
device=device,
|
||||||
hidden_sizes=hidden_sizes,
|
hidden_sizes=hidden_sizes,
|
||||||
|
history_length=history_length if enc_out else 0,
|
||||||
|
raw_obs_dim=raw_obs_dim,
|
||||||
|
embedding_dim=enc_out or 32,
|
||||||
)
|
)
|
||||||
model.load_state_dict(ckpt["policy"])
|
model.load_state_dict(ckpt["policy"])
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -163,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
|||||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
choices = HydraConfig.get().runtime.choices
|
choices = HydraConfig.get().runtime.choices
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "rotary_cartpole")
|
||||||
runner_name = choices.get("runner", "mujoco_single")
|
runner_name = choices.get("runner", "mujoco_single")
|
||||||
|
|
||||||
checkpoint_path = cfg.get("checkpoint", None)
|
checkpoint_path = cfg.get("checkpoint", None)
|
||||||
@@ -194,7 +216,9 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
|||||||
|
|
||||||
device = runner.device
|
device = runner.device
|
||||||
model, preprocessor = load_policy(
|
model, preprocessor = load_policy(
|
||||||
checkpoint_path, runner.observation_space, runner.action_space, device
|
checkpoint_path, runner.observation_space, runner.action_space, device,
|
||||||
|
history_length=runner.config.history_length,
|
||||||
|
raw_obs_dim=runner.env.observation_space.shape[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
mj_model = runner._model
|
mj_model = runner._model
|
||||||
@@ -280,7 +304,9 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
|||||||
|
|
||||||
device = serial_runner.device
|
device = serial_runner.device
|
||||||
model, preprocessor = load_policy(
|
model, preprocessor = load_policy(
|
||||||
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device
|
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
|
||||||
|
history_length=serial_runner.config.history_length,
|
||||||
|
raw_obs_dim=serial_runner.env.observation_space.shape[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up digital-twin MuJoCo model for visualization.
|
# Set up digital-twin MuJoCo model for visualization.
|
||||||
@@ -307,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
|||||||
if _reset_flag[0]:
|
if _reset_flag[0]:
|
||||||
_reset_flag[0] = False
|
_reset_flag[0] = False
|
||||||
serial_runner._send("M0")
|
serial_runner._send("M0")
|
||||||
serial_runner._drive_to_center()
|
obs, _ = serial_runner.reset() # drives to center + settles
|
||||||
serial_runner._wait_for_pendulum_still()
|
|
||||||
obs, _ = serial_runner.reset()
|
|
||||||
step = 0
|
step = 0
|
||||||
episode += 1
|
episode += 1
|
||||||
episode_reward = 0.0
|
episode_reward = 0.0
|
||||||
@@ -344,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
|||||||
"step", n=step, reward=round(reward.item(), 3),
|
"step", n=step, reward=round(reward.item(), 3),
|
||||||
action=round(action[0, 0].item(), 2),
|
action=round(action[0, 0].item(), 2),
|
||||||
ep_reward=round(episode_reward, 1),
|
ep_reward=round(episode_reward, 1),
|
||||||
motor_enc=state["encoder_count"],
|
motor_deg=round(math.degrees(state["motor_rad"]), 1),
|
||||||
pend_deg=round(state["pendulum_angle"], 1),
|
pend_deg=round(math.degrees(state["pend_rad"]), 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for safety / disconnection.
|
# Check for safety / disconnection.
|
||||||
|
|||||||
@@ -352,7 +352,7 @@ def main() -> None:
|
|||||||
reuse_last_task_id=False,
|
reuse_last_task_id=False,
|
||||||
)
|
)
|
||||||
task.set_base_docker(
|
task.set_base_docker(
|
||||||
docker_image="registry.kube.optimize/worker-image:latest",
|
docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||||
docker_arguments=[
|
docker_arguments=[
|
||||||
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
|
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
|
||||||
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
|
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
"""Unified CLI for motor-only system identification.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python scripts/motor_sysid.py capture --duration 20
|
|
||||||
python scripts/motor_sysid.py optimize --recording assets/motor/recordings/<file>.npz
|
|
||||||
python scripts/motor_sysid.py visualize --recording assets/motor/recordings/<file>.npz
|
|
||||||
python scripts/motor_sysid.py export --result assets/motor/motor_sysid_result.json
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure project root is on sys.path
|
|
||||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
|
||||||
if _PROJECT_ROOT not in sys.path:
|
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
|
|
||||||
print(
|
|
||||||
"Motor System Identification\n"
|
|
||||||
"===========================\n"
|
|
||||||
"Usage: python scripts/motor_sysid.py <command> [options]\n"
|
|
||||||
"\n"
|
|
||||||
"Commands:\n"
|
|
||||||
" capture Record motor trajectory under PRBS excitation\n"
|
|
||||||
" optimize Run CMA-ES to fit motor parameters\n"
|
|
||||||
" visualize Plot real vs simulated motor response\n"
|
|
||||||
" export Write tuned MJCF + robot.yaml files\n"
|
|
||||||
"\n"
|
|
||||||
"Workflow:\n"
|
|
||||||
" 1. Flash sysid firmware to ESP32 (motor-only, no limits)\n"
|
|
||||||
" 2. python scripts/motor_sysid.py capture --duration 20\n"
|
|
||||||
" 3. python scripts/motor_sysid.py optimize --recording <file>.npz\n"
|
|
||||||
" 4. python scripts/motor_sysid.py visualize --recording <file>.npz\n"
|
|
||||||
"\n"
|
|
||||||
"Run '<command> --help' for command-specific options."
|
|
||||||
)
|
|
||||||
sys.exit(0)
|
|
||||||
|
|
||||||
command = sys.argv[1]
|
|
||||||
sys.argv = [f"motor_sysid {command}"] + sys.argv[2:]
|
|
||||||
|
|
||||||
if command == "capture":
|
|
||||||
from src.sysid.motor.capture import main as cmd_main
|
|
||||||
elif command == "optimize":
|
|
||||||
from src.sysid.motor.optimize import main as cmd_main
|
|
||||||
elif command == "visualize":
|
|
||||||
from src.sysid.motor.visualize import main as cmd_main
|
|
||||||
elif command == "export":
|
|
||||||
from src.sysid.motor.export import main as cmd_main
|
|
||||||
else:
|
|
||||||
print(f"Unknown command: {command}")
|
|
||||||
print("Available commands: capture, optimize, visualize, export")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
cmd_main()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -8,10 +8,12 @@ _PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
|||||||
if _PROJECT_ROOT not in sys.path:
|
if _PROJECT_ROOT not in sys.path:
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
|
# Headless rendering on Linux servers (must be set before mujoco import).
|
||||||
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
|
# EGL renders on the GPU directly (right for NVIDIA nodes) and avoids the
|
||||||
|
# brittle OSMesa/PyOpenGL stack. Forced (not setdefault) so a stale
|
||||||
|
# `-e MUJOCO_GL=osmesa` baked into a remote task can't override it.
|
||||||
if sys.platform == "linux":
|
if sys.platform == "linux":
|
||||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
os.environ["MUJOCO_GL"] = "egl"
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import hydra.utils as hydra_utils
|
import hydra.utils as hydra_utils
|
||||||
@@ -61,7 +63,7 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
|||||||
"""Initialize ClearML task with project structure and tags."""
|
"""Initialize ClearML task with project structure and tags."""
|
||||||
Task.ignore_requirements("torch")
|
Task.ignore_requirements("torch")
|
||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "rotary_cartpole")
|
||||||
runner_name = choices.get("runner", "mujoco")
|
runner_name = choices.get("runner", "mujoco")
|
||||||
training_name = choices.get("training", "ppo")
|
training_name = choices.get("training", "ppo")
|
||||||
|
|
||||||
@@ -71,14 +73,14 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
|||||||
|
|
||||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||||
task.set_base_docker(
|
task.set_base_docker(
|
||||||
"registry.kube.optimize/worker-image:latest",
|
"git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||||
docker_setup_bash_script=(
|
docker_setup_bash_script=(
|
||||||
"apt-get update && apt-get install -y --no-install-recommends "
|
"apt-get update && apt-get install -y --no-install-recommends "
|
||||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
"libegl1 libgl1 libglfw3 libosmesa6 && rm -rf /var/lib/apt/lists/* "
|
||||||
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
|
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0"
|
||||||
),
|
),
|
||||||
docker_arguments=[
|
docker_arguments=[
|
||||||
"-e", "MUJOCO_GL=osmesa",
|
"-e", "MUJOCO_GL=egl",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -111,7 +113,7 @@ def main(cfg: DictConfig) -> None:
|
|||||||
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
||||||
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "rotary_cartpole")
|
||||||
env = build_env(env_name, cfg)
|
env = build_env(env_name, cfg)
|
||||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||||
trainer_config = TrainerConfig(**training_dict)
|
trainer_config = TrainerConfig(**training_dict)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Usage (simulation):
|
Usage (simulation):
|
||||||
mjpython scripts/viz.py env=rotary_cartpole
|
mjpython scripts/viz.py env=rotary_cartpole
|
||||||
mjpython scripts/viz.py env=cartpole +com=true
|
mjpython scripts/viz.py env=rotary_cartpole +com=true
|
||||||
|
|
||||||
Usage (real hardware — digital twin):
|
Usage (real hardware — digital twin):
|
||||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
||||||
@@ -104,7 +104,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
|||||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
choices = HydraConfig.get().runtime.choices
|
choices = HydraConfig.get().runtime.choices
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "rotary_cartpole")
|
||||||
runner_name = choices.get("runner", "mujoco")
|
runner_name = choices.get("runner", "mujoco")
|
||||||
|
|
||||||
if runner_name == "serial":
|
if runner_name == "serial":
|
||||||
@@ -229,11 +229,12 @@ def _main_serial(cfg: DictConfig, env_name: str) -> None:
|
|||||||
_reset_flag[0] = False
|
_reset_flag[0] = False
|
||||||
serial_runner._send("M0")
|
serial_runner._send("M0")
|
||||||
serial_runner._drive_to_center()
|
serial_runner._drive_to_center()
|
||||||
serial_runner._wait_for_pendulum_still()
|
serial_runner._wait_for_settle()
|
||||||
logger.info("reset (drive-to-center + settle)")
|
logger.info("reset (drive-to-center + settle)")
|
||||||
|
|
||||||
# Send motor command to real hardware.
|
# Send motor command to real hardware (same PWM scaling as
|
||||||
motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
|
# the policy path: ctrl_range-limited).
|
||||||
|
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
|
||||||
serial_runner._send(f"M{motor_speed}")
|
serial_runner._send(f"M{motor_speed}")
|
||||||
|
|
||||||
# Sync MuJoCo model with real sensor data.
|
# Sync MuJoCo model with real sensor data.
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import abc
|
import abc
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import TypeVar, Generic, Any
|
from typing import TypeVar, Generic, Any
|
||||||
from gymnasium import spaces
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
from src.core.robot import RobotConfig, load_robot_config
|
from src.core.robot import RobotConfig, load_robot_config
|
||||||
|
|
||||||
@@ -38,7 +40,9 @@ class BaseEnv(abc.ABC, Generic[T]):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
|
def compute_rewards(
|
||||||
|
self, state: Any, actions: torch.Tensor, prev_actions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -48,6 +52,21 @@ class BaseEnv(abc.ABC, Generic[T]):
|
|||||||
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
|
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
|
||||||
return step_counts >= self.config.max_steps
|
return step_counts >= self.config.max_steps
|
||||||
|
|
||||||
|
def initial_state_ranges(
|
||||||
|
self, nq: int, nv: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""Per-DOF uniform ranges for initial-state randomization.
|
||||||
|
|
||||||
|
Returns (qpos_lo, qpos_hi, qvel_lo, qvel_hi) — offsets added to the
|
||||||
|
model's default state on every reset. All runners (CPU MuJoCo and
|
||||||
|
MJX) sample from these, so initial-state distributions stay
|
||||||
|
identical across backends. Default: small ±0.05 perturbation.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
np.full(nq, -0.05), np.full(nq, 0.05),
|
||||||
|
np.full(nv, -0.05), np.full(nv, 0.05),
|
||||||
|
)
|
||||||
|
|
||||||
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
|
def is_reset_ready(self, qpos: torch.Tensor, qvel: torch.Tensor) -> bool:
|
||||||
"""Check whether the physical robot has settled enough to start an episode.
|
"""Check whether the physical robot has settled enough to start an episode.
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,10 @@
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from src.core.env import BaseEnv, BaseEnvConfig
|
from src.core.env import BaseEnv, BaseEnvConfig
|
||||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
|
||||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||||
|
|
||||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||||
"cartpole": (CartPoleEnv, CartPoleConfig),
|
|
||||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
log = structlog.get_logger()
|
log = structlog.get_logger()
|
||||||
@@ -51,6 +52,9 @@ class ActuatorConfig:
|
|||||||
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
|
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
|
||||||
viscous_quadratic: float = 0.0 # velocity² drag coefficient
|
viscous_quadratic: float = 0.0 # velocity² drag coefficient
|
||||||
back_emf_gain: float = 0.0 # back-EMF torque reduction
|
back_emf_gain: float = 0.0 # back-EMF torque reduction
|
||||||
|
stribeck_friction_boost: float = 0.0 # extra static friction at low speed (N·m)
|
||||||
|
stribeck_vel: float = 2.0 # Stribeck decay velocity (rad/s)
|
||||||
|
action_bias: float = 0.0 # additive ctrl bias (driver asymmetry)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gear_avg(self) -> float:
|
def gear_avg(self) -> float:
|
||||||
@@ -66,10 +70,23 @@ class ActuatorConfig:
|
|||||||
or self.frictionloss != (0.0, 0.0)
|
or self.frictionloss != (0.0, 0.0)
|
||||||
or self.viscous_quadratic > 0
|
or self.viscous_quadratic > 0
|
||||||
or self.back_emf_gain > 0
|
or self.back_emf_gain > 0
|
||||||
|
or self.stribeck_friction_boost > 0
|
||||||
|
or self.action_bias != 0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
def transform_ctrl(self, ctrl: float) -> float:
|
def transform_ctrl(self, ctrl: float) -> float:
|
||||||
"""Apply asymmetric deadzone and gear compensation to a scalar ctrl."""
|
"""Clip to ctrl_range, then apply bias, deadzone and gear compensation.
|
||||||
|
|
||||||
|
Must stay in lock-step with the vectorised JAX version in
|
||||||
|
``src/runners/mjx.py`` (step_fn) — sysid fits parameters against
|
||||||
|
THIS function, so any drift breaks the identified model.
|
||||||
|
"""
|
||||||
|
# Clip to ctrl_range first (mirrors firmware PWM saturation).
|
||||||
|
ctrl = max(self.ctrl_range[0], min(self.ctrl_range[1], ctrl))
|
||||||
|
|
||||||
|
# Additive driver bias (e.g. H-bridge asymmetry).
|
||||||
|
ctrl += self.action_bias
|
||||||
|
|
||||||
# Deadzone
|
# Deadzone
|
||||||
dz_pos, dz_neg = self.deadzone
|
dz_pos, dz_neg = self.deadzone
|
||||||
if ctrl >= 0 and ctrl < dz_pos:
|
if ctrl >= 0 and ctrl < dz_pos:
|
||||||
@@ -85,18 +102,31 @@ class ActuatorConfig:
|
|||||||
|
|
||||||
return ctrl
|
return ctrl
|
||||||
|
|
||||||
def compute_motor_force(self, vel: float, ctrl: float) -> float:
|
def compute_motor_force(self, vel: float, ctrl: float,
|
||||||
"""Asymmetric friction, damping, drag, back-EMF → applied torque."""
|
friction_scale: float = 1.0,
|
||||||
|
damping_scale: float = 1.0) -> float:
|
||||||
|
"""Asymmetric friction (Coulomb + Stribeck), damping, drag, back-EMF.
|
||||||
|
|
||||||
|
``friction_scale`` / ``damping_scale`` multiply the friction and
|
||||||
|
viscous-damping terms for per-env domain randomization
|
||||||
|
(1.0 = no randomization, the default used by sysid).
|
||||||
|
"""
|
||||||
torque = 0.0
|
torque = 0.0
|
||||||
|
|
||||||
# Coulomb friction (direction-dependent)
|
# Coulomb + Stribeck friction (direction-dependent). The Stribeck
|
||||||
|
# boost adds extra friction at low speed that decays as exp(-(v/vs)²)
|
||||||
|
# — crucial for cheap brushed motors near standstill.
|
||||||
fl_pos, fl_neg = self.frictionloss
|
fl_pos, fl_neg = self.frictionloss
|
||||||
if abs(vel) > 1e-6:
|
if abs(vel) > 1e-6:
|
||||||
fl = fl_pos if vel > 0 else fl_neg
|
fl = fl_pos if vel > 0 else fl_neg
|
||||||
torque -= math.copysign(fl, vel)
|
if self.stribeck_friction_boost > 0:
|
||||||
|
fl += self.stribeck_friction_boost * math.exp(
|
||||||
|
-((abs(vel) / self.stribeck_vel) ** 2)
|
||||||
|
)
|
||||||
|
torque -= math.copysign(fl * friction_scale, vel)
|
||||||
|
|
||||||
# Viscous damping (direction-dependent)
|
# Viscous damping (direction-dependent)
|
||||||
damp = self.damping[0] if vel > 0 else self.damping[1]
|
damp = (self.damping[0] if vel > 0 else self.damping[1]) * damping_scale
|
||||||
torque -= damp * vel
|
torque -= damp * vel
|
||||||
|
|
||||||
# Quadratic velocity drag
|
# Quadratic velocity drag
|
||||||
@@ -110,20 +140,26 @@ class ActuatorConfig:
|
|||||||
return max(-10.0, min(10.0, torque))
|
return max(-10.0, min(10.0, torque))
|
||||||
|
|
||||||
def transform_action(self, action):
|
def transform_action(self, action):
|
||||||
"""Vectorised deadzone + gear compensation for a torch batch."""
|
"""Vectorised clip + bias + deadzone + gear compensation (torch batch).
|
||||||
|
|
||||||
|
Must produce the same result as ``transform_ctrl`` element-wise.
|
||||||
|
"""
|
||||||
|
action = action.clamp(self.ctrl_range[0], self.ctrl_range[1])
|
||||||
|
action = action + self.action_bias
|
||||||
|
|
||||||
dz_pos, dz_neg = self.deadzone
|
dz_pos, dz_neg = self.deadzone
|
||||||
if dz_pos > 0 or dz_neg > 0:
|
if dz_pos > 0 or dz_neg > 0:
|
||||||
action = action.clone()
|
|
||||||
pos_dead = (action >= 0) & (action < dz_pos)
|
pos_dead = (action >= 0) & (action < dz_pos)
|
||||||
neg_dead = (action < 0) & (action > -dz_neg)
|
neg_dead = (action < 0) & (action > -dz_neg)
|
||||||
action[pos_dead | neg_dead] = 0.0
|
action = action.masked_fill(pos_dead | neg_dead, 0.0)
|
||||||
|
|
||||||
gear_avg = self.gear_avg
|
gear_avg = self.gear_avg
|
||||||
if gear_avg > 1e-8 and self.gear[0] != self.gear[1]:
|
if gear_avg > 1e-8 and self.gear[0] != self.gear[1]:
|
||||||
action = action.clone() if dz_pos == 0 and dz_neg == 0 else action
|
|
||||||
pos = action >= 0
|
pos = action >= 0
|
||||||
action[pos] *= self.gear[0] / gear_avg
|
action = torch.where(
|
||||||
action[~pos] *= self.gear[1] / gear_avg
|
pos, action * (self.gear[0] / gear_avg),
|
||||||
|
action * (self.gear[1] / gear_avg),
|
||||||
|
)
|
||||||
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
@@ -169,9 +205,18 @@ def load_robot_config(robot_dir: str | Path) -> RobotConfig:
|
|||||||
if not urdf_path.exists():
|
if not urdf_path.exists():
|
||||||
raise FileNotFoundError(f"URDF not found: {urdf_path}")
|
raise FileNotFoundError(f"URDF not found: {urdf_path}")
|
||||||
|
|
||||||
# Parse actuators
|
# Parse actuators — ignore unknown keys (newer sysid exports may add
|
||||||
|
# fields before the loader learns about them) instead of crashing.
|
||||||
|
known_fields = {f.name for f in dataclasses.fields(ActuatorConfig)}
|
||||||
actuators = []
|
actuators = []
|
||||||
for a in raw.get("actuators", []):
|
for a in raw.get("actuators", []):
|
||||||
|
unknown = set(a) - known_fields
|
||||||
|
if unknown:
|
||||||
|
log.warning(
|
||||||
|
"robot_yaml_unknown_actuator_keys",
|
||||||
|
keys=sorted(unknown), file=str(yaml_path),
|
||||||
|
)
|
||||||
|
a = {k: v for k, v in a.items() if k in known_fields}
|
||||||
if "ctrl_range" in a:
|
if "ctrl_range" in a:
|
||||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||||
for key in ("gear", "deadzone", "damping", "frictionloss"):
|
for key in ("gear", "deadzone", "damping", "frictionloss"):
|
||||||
|
|||||||
@@ -14,6 +14,19 @@ T = TypeVar("T")
|
|||||||
class BaseRunnerConfig:
|
class BaseRunnerConfig:
|
||||||
num_envs: int = 1
|
num_envs: int = 1
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
history_length: int = 0 # 0 = plain obs, >0 = append (obs, action) history
|
||||||
|
|
||||||
|
# ── Domain randomization (sim-to-real) ─────────────────────────
|
||||||
|
# Empty dict = disabled (every field below is a no-op). Supported keys:
|
||||||
|
# qpos_noise_std: float — Gaussian sensor noise on joint angles (rad)
|
||||||
|
# qvel_noise_std: float — Gaussian sensor noise on joint velocities (rad/s)
|
||||||
|
# action_delay_steps: [lo, hi] — per-env integer control-step latency
|
||||||
|
# friction_scale: [lo, hi] — per-env multiplier on Coulomb friction
|
||||||
|
# damping_scale: [lo, hi] — per-env multiplier on viscous damping
|
||||||
|
# torque_scale: [lo, hi] — per-env multiplier on applied motor torque
|
||||||
|
# With history_length > 0 the policy can implicitly infer the sampled
|
||||||
|
# dynamics from the recent (obs, action) window — end-to-end adaptation.
|
||||||
|
domain_rand: dict = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
class BaseRunner(abc.ABC, Generic[T]):
|
class BaseRunner(abc.ABC, Generic[T]):
|
||||||
def __init__(self, env: BaseEnv, config: T) -> None:
|
def __init__(self, env: BaseEnv, config: T) -> None:
|
||||||
@@ -36,6 +49,28 @@ class BaseRunner(abc.ABC, Generic[T]):
|
|||||||
self.config.num_envs, dtype=torch.long, device=self.config.device
|
self.config.num_envs, dtype=torch.long, device=self.config.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── Domain randomization (latency / sensor noise / dynamics) ─
|
||||||
|
self._setup_domain_rand()
|
||||||
|
|
||||||
|
# ── History buffer (implicit adaptation input) ────────────
|
||||||
|
self._history_len: int = getattr(self.config, "history_length", 0)
|
||||||
|
if self._history_len > 0:
|
||||||
|
obs_dim = self.observation_space.shape[0]
|
||||||
|
act_dim = self.action_space.shape[0]
|
||||||
|
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
|
||||||
|
# Ring buffer: (num_envs, history_length, obs_dim + act_dim)
|
||||||
|
self._history_buf = torch.zeros(
|
||||||
|
self.config.num_envs, self._history_len, self._history_step_dim,
|
||||||
|
device=self.config.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Policy obs = [raw_obs, history_flat]
|
||||||
|
from gymnasium import spaces
|
||||||
|
aug_dim = obs_dim + self._history_len * self._history_step_dim
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-torch.inf, high=torch.inf, shape=(aug_dim,),
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def num_envs(self) -> int:
|
def num_envs(self) -> int:
|
||||||
@@ -56,6 +91,12 @@ class BaseRunner(abc.ABC, Generic[T]):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Reset the given envs; return FULL-batch (num_envs, nq/nv) state.
|
||||||
|
|
||||||
|
Returning the full batch (not just the reset envs) lets GPU
|
||||||
|
backends hand back zero-copy views without host synchronisation —
|
||||||
|
the caller indexes the reset rows itself.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def _sim_close(self) -> None:
|
def _sim_close(self) -> None:
|
||||||
@@ -63,43 +104,185 @@ class BaseRunner(abc.ABC, Generic[T]):
|
|||||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||||
self._offscreen_renderer.close()
|
self._offscreen_renderer.close()
|
||||||
|
|
||||||
|
# ── Domain randomization ─────────────────────────────────────
|
||||||
|
|
||||||
|
_SCALE_FIELDS = ("friction_scale", "damping_scale", "torque_scale")
|
||||||
|
|
||||||
|
def _setup_domain_rand(self) -> None:
|
||||||
|
"""Parse the domain_rand config into per-env buffers.
|
||||||
|
|
||||||
|
All buffers are no-ops when ``domain_rand`` is empty: scales are 1.0,
|
||||||
|
delay is 0 and noise std is 0.
|
||||||
|
"""
|
||||||
|
dr = dict(getattr(self.config, "domain_rand", {}) or {})
|
||||||
|
n = self.config.num_envs
|
||||||
|
dev = self.config.device
|
||||||
|
|
||||||
|
# Fixed (not per-env) Gaussian sensor noise.
|
||||||
|
self._qpos_noise_std = float(dr.get("qpos_noise_std", 0.0))
|
||||||
|
self._qvel_noise_std = float(dr.get("qvel_noise_std", 0.0))
|
||||||
|
|
||||||
|
# Per-env multiplicative dynamics scales (applied by the sim runner).
|
||||||
|
self._dr_scales: dict[str, torch.Tensor] = {
|
||||||
|
f: torch.ones(n, device=dev) for f in self._SCALE_FIELDS
|
||||||
|
}
|
||||||
|
self._dr_scale_ranges: dict[str, tuple[float, float]] = {}
|
||||||
|
for f in self._SCALE_FIELDS:
|
||||||
|
rng = dr.get(f)
|
||||||
|
if rng:
|
||||||
|
self._dr_scale_ranges[f] = (float(rng[0]), float(rng[1]))
|
||||||
|
|
||||||
|
# Per-env integer action delay (in control steps).
|
||||||
|
self._dr_delay = torch.zeros(n, dtype=torch.long, device=dev)
|
||||||
|
delay_range = dr.get("action_delay_steps")
|
||||||
|
if delay_range:
|
||||||
|
self._delay_range = (int(delay_range[0]), int(delay_range[1]))
|
||||||
|
self._max_delay = int(delay_range[1])
|
||||||
|
else:
|
||||||
|
self._delay_range = (0, 0)
|
||||||
|
self._max_delay = 0
|
||||||
|
|
||||||
|
# Action-delay ring buffer: (num_envs, max_delay + 1, act_dim).
|
||||||
|
if self._max_delay > 0:
|
||||||
|
act_dim = self.env.action_space.shape[0]
|
||||||
|
self._action_buf = torch.zeros(
|
||||||
|
n, self._max_delay + 1, act_dim, device=dev,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resample_domain_rand(self, env_ids: torch.Tensor) -> None:
|
||||||
|
"""Sample fresh per-env DR factors (call on every (re)set)."""
|
||||||
|
if env_ids.numel() == 0:
|
||||||
|
return
|
||||||
|
dev = self.config.device
|
||||||
|
for name, (lo, hi) in self._dr_scale_ranges.items():
|
||||||
|
vals = torch.rand(env_ids.numel(), device=dev) * (hi - lo) + lo
|
||||||
|
self._dr_scales[name][env_ids] = vals
|
||||||
|
if self._max_delay > 0:
|
||||||
|
self._dr_delay[env_ids] = torch.randint(
|
||||||
|
self._delay_range[0], self._delay_range[1] + 1,
|
||||||
|
(env_ids.numel(),), device=dev,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reset_action_buffer(self, env_ids: torch.Tensor) -> None:
|
||||||
|
if self._max_delay > 0:
|
||||||
|
self._action_buf[env_ids] = 0.0
|
||||||
|
|
||||||
|
def _apply_action_delay(self, actions: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Return the per-env delayed action that the simulator should apply.
|
||||||
|
|
||||||
|
The policy's commanded action is what gets stored in history; only
|
||||||
|
the action handed to ``_sim_step`` is delayed.
|
||||||
|
"""
|
||||||
|
if self._max_delay <= 0:
|
||||||
|
return actions
|
||||||
|
self._action_buf = torch.roll(self._action_buf, 1, dims=1)
|
||||||
|
self._action_buf[:, 0] = actions
|
||||||
|
idx = torch.arange(self.num_envs, device=self.device)
|
||||||
|
return self._action_buf[idx, self._dr_delay]
|
||||||
|
|
||||||
|
def _add_sensor_noise(
|
||||||
|
self, qpos: torch.Tensor, qvel: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self._qpos_noise_std > 0:
|
||||||
|
qpos = qpos + torch.randn_like(qpos) * self._qpos_noise_std
|
||||||
|
if self._qvel_noise_std > 0:
|
||||||
|
qvel = qvel + torch.randn_like(qvel) * self._qvel_noise_std
|
||||||
|
return qpos, qvel
|
||||||
|
|
||||||
|
def _compute_obs(self, qpos: torch.Tensor, qvel: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Observation the policy sees — built from noisy (sensor) state."""
|
||||||
|
nqpos, nqvel = self._add_sensor_noise(qpos, qvel)
|
||||||
|
return self.env.compute_observations(self.env.build_state(nqpos, nqvel))
|
||||||
|
|
||||||
|
# ── Observation augmentation ─────────────────────────────────
|
||||||
|
|
||||||
|
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Append the flattened (obs, action) history when enabled."""
|
||||||
|
if self._history_len <= 0:
|
||||||
|
return obs
|
||||||
|
hist_flat = self._history_buf.reshape(obs.shape[0], -1)
|
||||||
|
return torch.cat([obs, hist_flat], dim=-1)
|
||||||
|
|
||||||
|
def _push_history(self, obs: torch.Tensor, actions: torch.Tensor,
|
||||||
|
env_ids: torch.Tensor | None = None) -> None:
|
||||||
|
"""Push (obs, action) into the ring buffer (shift left, append right)."""
|
||||||
|
if self._history_len <= 0:
|
||||||
|
return
|
||||||
|
step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1)
|
||||||
|
if env_ids is None:
|
||||||
|
# All envs.
|
||||||
|
self._history_buf = torch.roll(self._history_buf, -1, dims=1)
|
||||||
|
self._history_buf[:, -1] = step
|
||||||
|
else:
|
||||||
|
self._history_buf[env_ids] = torch.roll(
|
||||||
|
self._history_buf[env_ids], -1, dims=1
|
||||||
|
)
|
||||||
|
self._history_buf[env_ids, -1] = step[env_ids]
|
||||||
|
|
||||||
|
def _reset_history(self, env_ids: torch.Tensor) -> None:
|
||||||
|
"""Zero the history buffer for reset envs."""
|
||||||
|
if self._history_len > 0:
|
||||||
|
self._history_buf[env_ids] = 0.0
|
||||||
|
|
||||||
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
|
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
|
||||||
all_ids = torch.arange(self.num_envs, device=self.device)
|
all_ids = torch.arange(self.num_envs, device=self.device)
|
||||||
|
self._resample_domain_rand(all_ids)
|
||||||
|
self._reset_action_buffer(all_ids)
|
||||||
qpos, qvel = self._sim_reset(all_ids)
|
qpos, qvel = self._sim_reset(all_ids)
|
||||||
self.step_counts.zero_()
|
self.step_counts.zero_()
|
||||||
|
self._reset_history(all_ids)
|
||||||
|
|
||||||
state = self.env.build_state(qpos, qvel)
|
obs = self._compute_obs(qpos, qvel)
|
||||||
obs = self.env.compute_observations(state)
|
return self._augment_obs(obs), {}
|
||||||
return obs, {}
|
|
||||||
|
|
||||||
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
|
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
|
||||||
|
prev_actions = (
|
||||||
|
self._last_actions
|
||||||
|
if self._last_actions is not None
|
||||||
|
else torch.zeros_like(actions)
|
||||||
|
)
|
||||||
self._last_actions = actions
|
self._last_actions = actions
|
||||||
qpos, qvel = self._sim_step(actions)
|
# Latency: the simulator applies a (per-env) delayed action.
|
||||||
|
sim_actions = self._apply_action_delay(actions)
|
||||||
|
qpos, qvel = self._sim_step(sim_actions)
|
||||||
self.step_counts += 1
|
self.step_counts += 1
|
||||||
|
|
||||||
state = self.env.build_state(qpos, qvel)
|
# Reward / termination use the TRUE state (no sensor noise) so the
|
||||||
obs = self.env.compute_observations(state)
|
# learning signal and safety checks stay clean.
|
||||||
rewards = self.env.compute_rewards(state, actions)
|
clean_state = self.env.build_state(qpos, qvel)
|
||||||
terminated = self.env.compute_terminations(state)
|
rewards = self.env.compute_rewards(clean_state, actions, prev_actions)
|
||||||
|
terminated = self.env.compute_terminations(clean_state)
|
||||||
truncated = self.env.compute_truncations(self.step_counts)
|
truncated = self.env.compute_truncations(self.step_counts)
|
||||||
|
|
||||||
|
# The observation the policy sees is built from the NOISY sensor state.
|
||||||
|
obs = self._compute_obs(qpos, qvel)
|
||||||
|
|
||||||
|
# Push current (obs, action) into history before augmenting.
|
||||||
|
self._push_history(obs, actions)
|
||||||
|
|
||||||
info: dict[str, Any] = {}
|
info: dict[str, Any] = {}
|
||||||
|
|
||||||
done = terminated | truncated
|
done = terminated | truncated
|
||||||
done_ids = done.nonzero(as_tuple=False).squeeze(-1)
|
done_ids = done.nonzero(as_tuple=False).squeeze(-1)
|
||||||
|
|
||||||
if done_ids.numel() > 0:
|
if done_ids.numel() > 0:
|
||||||
info["final_observations"] = obs[done_ids].clone()
|
info["final_observations"] = self._augment_obs(obs)[done_ids].clone()
|
||||||
info["final_env_ids"] = done_ids.clone()
|
info["final_env_ids"] = done_ids.clone()
|
||||||
|
|
||||||
reset_qpos, reset_qvel = self._sim_reset(done_ids)
|
# New episode → fresh dynamics + cleared latency buffer.
|
||||||
|
self._resample_domain_rand(done_ids)
|
||||||
|
self._reset_action_buffer(done_ids)
|
||||||
|
full_qpos, full_qvel = self._sim_reset(done_ids)
|
||||||
self.step_counts[done_ids] = 0
|
self.step_counts[done_ids] = 0
|
||||||
|
self._reset_history(done_ids)
|
||||||
|
|
||||||
reset_state = self.env.build_state(reset_qpos, reset_qvel)
|
# _sim_reset returns the full batch — index the reset rows here.
|
||||||
obs[done_ids] = self.env.compute_observations(reset_state)
|
obs[done_ids] = self._compute_obs(
|
||||||
|
full_qpos[done_ids], full_qvel[done_ids],
|
||||||
|
)
|
||||||
|
|
||||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||||
|
|
||||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||||
"""Return a raw RGB frame. Override in subclass."""
|
"""Return a raw RGB frame. Override in subclass."""
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
import dataclasses
|
|
||||||
import torch
|
|
||||||
from src.core.env import BaseEnv, BaseEnvConfig
|
|
||||||
from gymnasium import spaces
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class CartPoleState:
|
|
||||||
cart_pos: torch.float # (num_envs,)
|
|
||||||
cart_vel: torch.float # (num_envs,)
|
|
||||||
pole_angle: torch.float # (num_envs,)
|
|
||||||
pole_vel: torch.float # (num_envs,)
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class CartPoleConfig(BaseEnvConfig):
|
|
||||||
"""CartPole task config. All values come from Hydra YAML."""
|
|
||||||
angle_threshold: float = 0.418 # ~24 degrees
|
|
||||||
cart_limit: float = 2.4
|
|
||||||
reward_alive: float = 1.0
|
|
||||||
reward_pole_upright_scale: float = 1.0
|
|
||||||
reward_action_penalty_scale: float = 0.01
|
|
||||||
|
|
||||||
class CartPoleEnv(BaseEnv[CartPoleConfig]):
|
|
||||||
def __init__(self, config: CartPoleConfig):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def observation_space(self) -> torch.Tensor:
|
|
||||||
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_space(self) -> torch.Tensor:
|
|
||||||
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
|
|
||||||
|
|
||||||
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
|
|
||||||
return CartPoleState(
|
|
||||||
cart_pos=qpos[:, 0],
|
|
||||||
cart_vel=qvel[:, 0],
|
|
||||||
pole_angle=qpos[:, 1],
|
|
||||||
pole_vel=qvel[:, 1],
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_observations(self, state: CartPoleState) -> torch.Tensor:
|
|
||||||
return torch.stack([state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1)
|
|
||||||
|
|
||||||
def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
|
||||||
upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
|
|
||||||
action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
|
|
||||||
return self.config.reward_alive + upright - action_penalty
|
|
||||||
|
|
||||||
def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
|
|
||||||
pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
|
|
||||||
cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
|
|
||||||
return pole_fallen | cart_out_of_bounds
|
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from src.core.env import BaseEnv, BaseEnvConfig
|
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from src.core.env import BaseEnv, BaseEnvConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class RotaryCartPoleState:
|
class RotaryCartPoleState:
|
||||||
@@ -21,10 +24,18 @@ class RotaryCartPoleConfig(BaseEnvConfig):
|
|||||||
at the arm tip. Goal: swing the pendulum up and balance it upright.
|
at the arm tip. Goal: swing the pendulum up and balance it upright.
|
||||||
"""
|
"""
|
||||||
# Reward shaping
|
# Reward shaping
|
||||||
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
|
reward_upright_scale: float = 1.0 # upright reward ∈ [0, scale]
|
||||||
|
alive_bonus: float = 0.25 # per-step survival bonus (must stay alive > die)
|
||||||
|
balance_bonus: float = 2.0 # extra reward for upright AND still (beats spinning)
|
||||||
|
balance_vel_scale: float = 0.5 # decay rate of the bonus with pendulum speed
|
||||||
motor_vel_penalty: float = 0.01 # penalise high motor angular velocity
|
motor_vel_penalty: float = 0.01 # penalise high motor angular velocity
|
||||||
motor_angle_penalty: float = 0.05 # penalise deviation from centre
|
motor_angle_penalty: float = 0.05 # penalise deviation from centre
|
||||||
action_penalty: float = 0.05 # penalise large actions (energy cost)
|
action_penalty: float = 0.05 # penalise large actions (energy cost)
|
||||||
|
action_rate_penalty: float = 0.01 # penalise action changes (smoothness —
|
||||||
|
# critical with ~100 ms real motor lag)
|
||||||
|
|
||||||
|
# ── Initial state randomisation ──
|
||||||
|
pendulum_init_range_deg: float = 180.0 # pendulum starts in [-range, +range]
|
||||||
|
|
||||||
# ── Software safety limit (env-level, on top of URDF + hardware) ──
|
# ── Software safety limit (env-level, on top of URDF + hardware) ──
|
||||||
motor_angle_limit_deg: float = 90.0 # terminate episode if exceeded
|
motor_angle_limit_deg: float = 90.0 # terminate episode if exceeded
|
||||||
@@ -81,9 +92,29 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
|||||||
|
|
||||||
# ── Rewards ──────────────────────────────────────────────────
|
# ── Rewards ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
def compute_rewards(
|
||||||
# Upright reward: -cos(θ) ∈ [-1, +1]
|
self,
|
||||||
reward = -torch.cos(state.pendulum_angle) * self.config.reward_upright_scale
|
state: RotaryCartPoleState,
|
||||||
|
actions: torch.Tensor,
|
||||||
|
prev_actions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Upright shaping ∈ [0, 1]: 0 hanging down (θ=0), 1 fully upright (θ=π).
|
||||||
|
# Non-negative by design so *surviving* always beats ending the episode early
|
||||||
|
# (otherwise the optimum is to slam the arm into the ±limit — "suicide policy").
|
||||||
|
upright = 0.5 * (1.0 - torch.cos(state.pendulum_angle))
|
||||||
|
|
||||||
|
# Balanced bonus: large ONLY when near the top AND nearly still. A freely
|
||||||
|
# spinning pendulum passes through the top at high speed, so stillness≈0 and
|
||||||
|
# it earns ~none of this — making true balancing strictly dominate the
|
||||||
|
# "just keep spinning in full loops" local optimum.
|
||||||
|
stillness = torch.exp(-self.config.balance_vel_scale * state.pendulum_vel.pow(2))
|
||||||
|
balance = self.config.balance_bonus * upright * stillness
|
||||||
|
|
||||||
|
# Per-step alive bonus keeps a not-yet-upright step net-positive despite
|
||||||
|
# penalties, so the −10 termination penalty is genuinely a deterrent.
|
||||||
|
reward = (upright * self.config.reward_upright_scale
|
||||||
|
+ balance
|
||||||
|
+ self.config.alive_bonus)
|
||||||
|
|
||||||
# Penalise fast motor spinning (discourages violent oscillation)
|
# Penalise fast motor spinning (discourages violent oscillation)
|
||||||
reward = reward - self.config.motor_vel_penalty * state.motor_vel.pow(2)
|
reward = reward - self.config.motor_vel_penalty * state.motor_vel.pow(2)
|
||||||
@@ -94,13 +125,34 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
|||||||
# Penalise large actions (energy efficiency / smoother control)
|
# Penalise large actions (energy efficiency / smoother control)
|
||||||
reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2)
|
reward = reward - self.config.action_penalty * actions.squeeze(-1).pow(2)
|
||||||
|
|
||||||
|
# Penalise rapid action changes — a jittery policy is unrealisable
|
||||||
|
# through the real motor's ~100 ms lag and excites unmodeled dynamics.
|
||||||
|
action_rate = (actions - prev_actions).squeeze(-1).pow(2)
|
||||||
|
reward = reward - self.config.action_rate_penalty * action_rate
|
||||||
|
|
||||||
# Penalty for exceeding motor angle limit (episode also terminates)
|
# Penalty for exceeding motor angle limit (episode also terminates)
|
||||||
limit_rad = math.radians(self.config.motor_angle_limit_deg)
|
limit_rad = math.radians(self.config.motor_angle_limit_deg)
|
||||||
exceeded = state.motor_angle.abs() >= limit_rad
|
exceeded = state.motor_angle.abs() >= limit_rad
|
||||||
reward = torch.where(exceeded, torch.tensor(-1000.0, device=reward.device), reward)
|
reward = torch.where(exceeded, torch.tensor(-10.0, device=reward.device), reward)
|
||||||
|
|
||||||
return reward
|
return reward
|
||||||
|
|
||||||
|
# ── Initial state randomization ──────────────────────────────
|
||||||
|
|
||||||
|
def initial_state_ranges(
|
||||||
|
self, nq: int, nv: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""Small motor perturbation; wide pendulum angle (swing-up task)."""
|
||||||
|
qpos_lo = np.full(nq, -0.05)
|
||||||
|
qpos_hi = np.full(nq, 0.05)
|
||||||
|
qvel_lo = np.full(nv, -0.05)
|
||||||
|
qvel_hi = np.full(nv, 0.05)
|
||||||
|
pend_range = math.radians(self.config.pendulum_init_range_deg)
|
||||||
|
if pend_range > 0 and nq >= 2:
|
||||||
|
qpos_lo[1] = -pend_range
|
||||||
|
qpos_hi[1] = pend_range
|
||||||
|
return qpos_lo, qpos_hi, qvel_lo, qvel_hi
|
||||||
|
|
||||||
# ── Terminations ─────────────────────────────────────────────
|
# ── Terminations ─────────────────────────────────────────────
|
||||||
|
|
||||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||||
|
|||||||
@@ -3,14 +3,95 @@ import torch.nn as nn
|
|||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEncoder(nn.Module):
|
||||||
|
"""1D-CNN encoder over a temporal window of (obs, action) pairs.
|
||||||
|
|
||||||
|
Input: (batch, history_length, step_dim)
|
||||||
|
Output: (batch, embedding_dim)
|
||||||
|
|
||||||
|
Architecture: two temporal conv layers → global average pool → linear.
|
||||||
|
Lets the policy implicitly infer the current dynamics (friction, torque
|
||||||
|
scale, latency, …) from how the system responded to recent actions —
|
||||||
|
end-to-end adaptation when trained under domain randomization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
history_length: int,
|
||||||
|
step_dim: int,
|
||||||
|
embedding_dim: int = 32,
|
||||||
|
hidden_channels: int = 32,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
# (batch, step_dim, history_length) after transpose
|
||||||
|
nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
nn.ELU(),
|
||||||
|
nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
|
||||||
|
nn.ELU(),
|
||||||
|
)
|
||||||
|
self.fc = nn.Linear(hidden_channels, embedding_dim)
|
||||||
|
|
||||||
|
def forward(self, history: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""history: (batch, history_length, step_dim)."""
|
||||||
|
# Conv1d expects (batch, channels, seq_len).
|
||||||
|
x = history.transpose(1, 2)
|
||||||
|
x = self.conv(x)
|
||||||
|
# Global average pool over time.
|
||||||
|
x = x.mean(dim=-1)
|
||||||
|
return self.fc(x)
|
||||||
|
|
||||||
|
|
||||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
"""Shared policy/value network with an optional history encoder.
|
||||||
|
|
||||||
|
With ``history_length > 0`` the input states are expected to be
|
||||||
|
``[raw_obs, history_flat]`` (as produced by ``BaseRunner``); the history
|
||||||
|
window is compressed by a :class:`HistoryEncoder` and concatenated with
|
||||||
|
the raw observation before the shared backbone.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
observation_space: spaces.Space,
|
||||||
|
action_space: spaces.Space,
|
||||||
|
device: torch.device,
|
||||||
|
hidden_sizes: tuple[int, ...] = (32, 32),
|
||||||
|
clip_actions: bool = False,
|
||||||
|
clip_log_std: bool = True,
|
||||||
|
min_log_std: float = -2.0,
|
||||||
|
max_log_std: float = 2.0,
|
||||||
|
initial_log_std: float = 0.0,
|
||||||
|
# ── History encoder ──────────────────────────────────────
|
||||||
|
history_length: int = 0,
|
||||||
|
raw_obs_dim: int = 0,
|
||||||
|
embedding_dim: int = 32,
|
||||||
|
):
|
||||||
Model.__init__(self, observation_space, action_space, device)
|
Model.__init__(self, observation_space, action_space, device)
|
||||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
GaussianMixin.__init__(
|
||||||
|
self, clip_actions, clip_log_std, min_log_std, max_log_std,
|
||||||
|
)
|
||||||
DeterministicMixin.__init__(self, clip_actions)
|
DeterministicMixin.__init__(self, clip_actions)
|
||||||
|
|
||||||
layers = []
|
self._history_length = history_length
|
||||||
in_dim: int = self.num_observations
|
self._raw_obs_dim = raw_obs_dim
|
||||||
|
self._embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
self.history_encoder: HistoryEncoder | None = None
|
||||||
|
if history_length > 0 and raw_obs_dim > 0:
|
||||||
|
step_dim = raw_obs_dim + self.num_actions
|
||||||
|
self.history_encoder = HistoryEncoder(
|
||||||
|
history_length=history_length,
|
||||||
|
step_dim=step_dim,
|
||||||
|
embedding_dim=embedding_dim,
|
||||||
|
)
|
||||||
|
in_dim = raw_obs_dim + embedding_dim
|
||||||
|
else:
|
||||||
|
in_dim = self.num_observations
|
||||||
|
|
||||||
|
# ── Shared backbone ──────────────────────────────────────
|
||||||
|
layers: list[nn.Module] = []
|
||||||
for hidden_size in hidden_sizes:
|
for hidden_size in hidden_sizes:
|
||||||
layers.append(nn.Linear(in_dim, hidden_size))
|
layers.append(nn.Linear(in_dim, hidden_size))
|
||||||
layers.append(nn.ELU())
|
layers.append(nn.ELU())
|
||||||
@@ -19,30 +100,45 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
|||||||
|
|
||||||
# Policy head
|
# Policy head
|
||||||
self.mean_layer = nn.Linear(in_dim, self.num_actions)
|
self.mean_layer = nn.Linear(in_dim, self.num_actions)
|
||||||
self.log_std_parameter: nn.Parameter = nn.Parameter(torch.full((self.num_actions,), initial_log_std))
|
self.log_std_parameter: nn.Parameter = nn.Parameter(
|
||||||
|
torch.full((self.num_actions,), initial_log_std),
|
||||||
|
)
|
||||||
|
|
||||||
# Value head
|
# Value head
|
||||||
self.value_layer = nn.Linear(in_dim, 1)
|
self.value_layer = nn.Linear(in_dim, 1)
|
||||||
self._shared_output: torch.Tensor | None = None
|
self._shared_output: torch.Tensor | None = None
|
||||||
|
|
||||||
|
def act(
|
||||||
def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
|
self, inputs: dict[str, torch.Tensor], role: str = "",
|
||||||
|
) -> tuple[torch.Tensor, ...]:
|
||||||
if role == "policy":
|
if role == "policy":
|
||||||
return GaussianMixin.act(self, inputs, role)
|
return GaussianMixin.act(self, inputs, role)
|
||||||
elif role == "value":
|
elif role == "value":
|
||||||
return DeterministicMixin.act(self, inputs, role)
|
return DeterministicMixin.act(self, inputs, role)
|
||||||
|
|
||||||
|
def _encode(self, states: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Optionally split off and encode the history window."""
|
||||||
|
if self.history_encoder is None:
|
||||||
|
return self.net(states)
|
||||||
|
|
||||||
|
obs = states[:, :self._raw_obs_dim]
|
||||||
|
hist_flat = states[:, self._raw_obs_dim:]
|
||||||
|
step_dim = self._raw_obs_dim + self.num_actions
|
||||||
|
history = hist_flat.reshape(-1, self._history_length, step_dim)
|
||||||
|
embedding = self.history_encoder(history)
|
||||||
|
return self.net(torch.cat([obs, embedding], dim=-1))
|
||||||
|
|
||||||
def compute(
|
def compute(
|
||||||
self, inputs: dict[str, torch.Tensor], role: str = ""
|
self, inputs: dict[str, torch.Tensor], role: str = "",
|
||||||
) -> tuple[torch.Tensor, ...]:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
if role == "policy":
|
if role == "policy":
|
||||||
self._shared_output = self.net(inputs["states"])
|
self._shared_output = self._encode(inputs["states"])
|
||||||
return self.mean_layer(self._shared_output), self.log_std_parameter, {}
|
return self.mean_layer(self._shared_output), self.log_std_parameter, {}
|
||||||
elif role == "value":
|
elif role == "value":
|
||||||
shared_output = (
|
shared_output = (
|
||||||
self._shared_output
|
self._shared_output
|
||||||
if self._shared_output is not None
|
if self._shared_output is not None
|
||||||
else self.net(inputs["states"])
|
else self._encode(inputs["states"])
|
||||||
)
|
)
|
||||||
self._shared_output = None
|
self._shared_output = None
|
||||||
return self.value_layer(shared_output), {}
|
return self.value_layer(shared_output), {}
|
||||||
@@ -9,10 +9,17 @@ Requirements:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import os
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
|
||||||
|
# default JAX preallocates ~75% of GPU memory on init, starving torch and
|
||||||
|
# causing OOM at the first PPO update. Disable preallocation so both libraries
|
||||||
|
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
|
||||||
|
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -81,7 +88,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
mujoco.mj_forward(self._mj_model, default_data)
|
mujoco.mj_forward(self._mj_model, default_data)
|
||||||
self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
|
self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
|
||||||
|
|
||||||
# Step 4: Initialise all environments with small perturbations
|
# Env-defined initial-state distribution (shared with the CPU
|
||||||
|
# runner) — baked into the JIT reset as constants.
|
||||||
|
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
|
||||||
|
self._nq, self._nv,
|
||||||
|
)
|
||||||
|
self._init_qpos_lo = jnp.asarray(qpos_lo)
|
||||||
|
self._init_qpos_hi = jnp.asarray(qpos_hi)
|
||||||
|
self._init_qvel_lo = jnp.asarray(qvel_lo)
|
||||||
|
self._init_qvel_hi = jnp.asarray(qvel_hi)
|
||||||
|
|
||||||
|
# Step 4: Initialise all environments with randomized states
|
||||||
self._rng = jax.random.PRNGKey(42)
|
self._rng = jax.random.PRNGKey(42)
|
||||||
self._batch_data = self._make_batched_data(config.num_envs)
|
self._batch_data = self._make_batched_data(config.num_envs)
|
||||||
|
|
||||||
@@ -103,6 +120,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
# Keep one CPU MjData for offscreen rendering
|
# Keep one CPU MjData for offscreen rendering
|
||||||
self._render_data = mujoco.MjData(self._mj_model)
|
self._render_data = mujoco.MjData(self._mj_model)
|
||||||
|
|
||||||
|
# Per-env DR scale arrays (synced from torch on every reset).
|
||||||
|
# Initialised to 1.0 here because _setup_domain_rand runs after this.
|
||||||
|
self._mjx_fr = jnp.ones(config.num_envs)
|
||||||
|
self._mjx_dp = jnp.ones(config.num_envs)
|
||||||
|
self._mjx_tq = jnp.ones(config.num_envs)
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"mjx_runner_ready",
|
"mjx_runner_ready",
|
||||||
num_envs=config.num_envs,
|
num_envs=config.num_envs,
|
||||||
@@ -111,10 +134,16 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _make_batched_data(self, n: int):
|
def _make_batched_data(self, n: int):
|
||||||
"""Create *n* environments with small random perturbations."""
|
"""Create *n* environments with env-defined initial randomization."""
|
||||||
self._rng, k1, k2 = jax.random.split(self._rng, 3)
|
self._rng, k1, k2 = jax.random.split(self._rng, 3)
|
||||||
pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05)
|
pq = jax.random.uniform(
|
||||||
pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05)
|
k1, (n, self._nq),
|
||||||
|
minval=self._init_qpos_lo, maxval=self._init_qpos_hi,
|
||||||
|
)
|
||||||
|
pv = jax.random.uniform(
|
||||||
|
k2, (n, self._nv),
|
||||||
|
minval=self._init_qvel_lo, maxval=self._init_qvel_hi,
|
||||||
|
)
|
||||||
|
|
||||||
default = self._default_mjx_data
|
default = self._default_mjx_data
|
||||||
model = self._mjx_model
|
model = self._mjx_model
|
||||||
@@ -141,11 +170,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
act_gs = jnp.array(lim.gear_sign)
|
act_gs = jnp.array(lim.gear_sign)
|
||||||
|
|
||||||
# ── Motor model params (JAX arrays for JIT) ─────────────────
|
# ── Motor model params (JAX arrays for JIT) ─────────────────
|
||||||
|
# Must stay in lock-step with ActuatorConfig.transform_ctrl() /
|
||||||
|
# compute_motor_force() in src/core/robot.py — sysid fits against
|
||||||
|
# the CPU implementation.
|
||||||
_has_motor = len(self._motor_info) > 0
|
_has_motor = len(self._motor_info) > 0
|
||||||
if _has_motor:
|
if _has_motor:
|
||||||
acts = self._motor_acts
|
acts = self._motor_acts
|
||||||
_ctrl_ids = jnp.array([c for c, _ in self._motor_info])
|
_ctrl_ids = jnp.array([c for c, _ in self._motor_info])
|
||||||
_qvel_ids = jnp.array([q for _, q in self._motor_info])
|
_qvel_ids = jnp.array([q for _, q in self._motor_info])
|
||||||
|
_ctrl_lo = jnp.array([a.ctrl_range[0] for a in acts])
|
||||||
|
_ctrl_hi = jnp.array([a.ctrl_range[1] for a in acts])
|
||||||
|
_bias = jnp.array([a.action_bias for a in acts])
|
||||||
_dz_pos = jnp.array([a.deadzone[0] for a in acts])
|
_dz_pos = jnp.array([a.deadzone[0] for a in acts])
|
||||||
_dz_neg = jnp.array([a.deadzone[1] for a in acts])
|
_dz_neg = jnp.array([a.deadzone[1] for a in acts])
|
||||||
_gear_pos = jnp.array([a.gear[0] for a in acts])
|
_gear_pos = jnp.array([a.gear[0] for a in acts])
|
||||||
@@ -153,14 +188,23 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
_gear_avg = jnp.array([a.gear_avg for a in acts])
|
_gear_avg = jnp.array([a.gear_avg for a in acts])
|
||||||
_fl_pos = jnp.array([a.frictionloss[0] for a in acts])
|
_fl_pos = jnp.array([a.frictionloss[0] for a in acts])
|
||||||
_fl_neg = jnp.array([a.frictionloss[1] for a in acts])
|
_fl_neg = jnp.array([a.frictionloss[1] for a in acts])
|
||||||
|
_strb_boost = jnp.array([a.stribeck_friction_boost for a in acts])
|
||||||
|
_strb_vel = jnp.array([a.stribeck_vel for a in acts])
|
||||||
_damp_pos = jnp.array([a.damping[0] for a in acts])
|
_damp_pos = jnp.array([a.damping[0] for a in acts])
|
||||||
_damp_neg = jnp.array([a.damping[1] for a in acts])
|
_damp_neg = jnp.array([a.damping[1] for a in acts])
|
||||||
_visc_quad = jnp.array([a.viscous_quadratic for a in acts])
|
_visc_quad = jnp.array([a.viscous_quadratic for a in acts])
|
||||||
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
||||||
|
|
||||||
# ── Batched step (N substeps per call) ──────────────────────
|
# ── Batched step (N substeps per call) ──────────────────────
|
||||||
|
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
|
||||||
|
# multipliers (1.0 = off). Passed as args (not closure constants) so
|
||||||
|
# resampling them every episode does NOT trigger JIT recompilation.
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def step_fn(data):
|
def step_fn(data, fr_scale, dp_scale, tq_scale):
|
||||||
|
fr = fr_scale[:, None] # broadcast over motor actuators
|
||||||
|
dp = dp_scale[:, None]
|
||||||
|
tq = tq_scale[:, None]
|
||||||
|
|
||||||
# Software limit switch: clamp ctrl once before substeps.
|
# Software limit switch: clamp ctrl once before substeps.
|
||||||
pos = data.qpos[:, act_jnt_ids]
|
pos = data.qpos[:, act_jnt_ids]
|
||||||
ctrl = data.ctrl
|
ctrl = data.ctrl
|
||||||
@@ -169,12 +213,16 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl)
|
ctrl = jnp.where(at_hi | at_lo, 0.0, ctrl)
|
||||||
|
|
||||||
if _has_motor:
|
if _has_motor:
|
||||||
# Deadzone + asymmetric gear compensation
|
# Clip → bias → deadzone → asymmetric gear compensation
|
||||||
|
# (same order as ActuatorConfig.transform_ctrl).
|
||||||
mc = ctrl[:, _ctrl_ids]
|
mc = ctrl[:, _ctrl_ids]
|
||||||
|
mc = jnp.clip(mc, _ctrl_lo, _ctrl_hi)
|
||||||
|
mc = mc + _bias
|
||||||
mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc)
|
mc = jnp.where((mc >= 0) & (mc < _dz_pos), 0.0, mc)
|
||||||
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
||||||
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
||||||
mc = mc * gear_dir / _gear_avg
|
mc = mc * gear_dir / _gear_avg
|
||||||
|
mc = mc * tq # torque_scale (DR)
|
||||||
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
||||||
|
|
||||||
data = data.replace(ctrl=ctrl)
|
data = data.replace(ctrl=ctrl)
|
||||||
@@ -184,13 +232,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
vel = d.qvel[:, _qvel_ids]
|
vel = d.qvel[:, _qvel_ids]
|
||||||
mc = d.ctrl[:, _ctrl_ids]
|
mc = d.ctrl[:, _ctrl_ids]
|
||||||
|
|
||||||
# Coulomb friction (direction-dependent)
|
# Coulomb + Stribeck friction (direction-dependent) × DR
|
||||||
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
|
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
|
||||||
|
fl = fl + _strb_boost * jnp.exp(
|
||||||
|
-((jnp.abs(vel) / _strb_vel) ** 2)
|
||||||
|
)
|
||||||
|
fl = fl * fr
|
||||||
torque = -jnp.where(
|
torque = -jnp.where(
|
||||||
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
||||||
)
|
)
|
||||||
# Viscous damping (direction-dependent)
|
# Viscous damping (direction-dependent) × DR scale
|
||||||
damp = jnp.where(vel > 0, _damp_pos, _damp_neg)
|
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
|
||||||
torque = torque - damp * vel
|
torque = torque - damp * vel
|
||||||
# Quadratic velocity drag
|
# Quadratic velocity drag
|
||||||
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
||||||
@@ -211,16 +263,23 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
self._jit_step = step_fn
|
self._jit_step = step_fn
|
||||||
|
|
||||||
# ── Selective reset ─────────────────────────────────────────
|
# ── Selective reset ─────────────────────────────────────────
|
||||||
|
init_qpos_lo = self._init_qpos_lo
|
||||||
|
init_qpos_hi = self._init_qpos_hi
|
||||||
|
init_qvel_lo = self._init_qvel_lo
|
||||||
|
init_qvel_hi = self._init_qvel_hi
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def reset_fn(data, mask, rng):
|
def reset_fn(data, mask, rng):
|
||||||
rng, k1, k2 = jax.random.split(rng, 3)
|
rng, k1, k2 = jax.random.split(rng, 3)
|
||||||
ne = data.qpos.shape[0]
|
ne = data.qpos.shape[0]
|
||||||
|
|
||||||
pq = jax.random.uniform(
|
pq = jax.random.uniform(
|
||||||
k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05,
|
k1, (ne, default.qpos.shape[0]),
|
||||||
|
minval=init_qpos_lo, maxval=init_qpos_hi,
|
||||||
)
|
)
|
||||||
pv = jax.random.uniform(
|
pv = jax.random.uniform(
|
||||||
k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05,
|
k2, (ne, default.qvel.shape[0]),
|
||||||
|
minval=init_qvel_lo, maxval=init_qvel_hi,
|
||||||
)
|
)
|
||||||
|
|
||||||
m = mask[:, None] # (num_envs, 1) broadcast helper
|
m = mask[:, None] # (num_envs, 1) broadcast helper
|
||||||
@@ -242,9 +301,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||||
|
|
||||||
# Set ctrl & run N substeps for all environments
|
# Set ctrl & run N substeps for all environments (with per-env DR scales)
|
||||||
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||||
self._batch_data = self._jit_step(self._batch_data)
|
self._batch_data = self._jit_step(
|
||||||
|
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
|
||||||
|
)
|
||||||
|
|
||||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||||
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||||
@@ -263,11 +324,18 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
self._batch_data, mask_jax, self._rng,
|
self._batch_data, mask_jax, self._rng,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only the reset environments' states
|
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
|
||||||
ids_np = env_ids.cpu().numpy()
|
# resamples self._dr_scales just before this call, so re-deriving the
|
||||||
rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
|
# full arrays here keeps the JAX copies current for every env.
|
||||||
rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
|
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
|
||||||
return torch.from_dlpack(rq), torch.from_dlpack(rv)
|
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
|
||||||
|
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
|
||||||
|
|
||||||
|
# Return the FULL batch (BaseRunner indexes the reset envs in torch)
|
||||||
|
# — avoids a GPU→CPU sync + JAX gather on every step with a done env.
|
||||||
|
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||||
|
qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
|
||||||
|
return qpos, qvel
|
||||||
|
|
||||||
# ── Rendering ────────────────────────────────────────────────────
|
# ── Rendering ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -246,14 +246,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||||
|
|
||||||
|
# Per-env domain-randomization scales (all 1.0 when DR is disabled).
|
||||||
|
fr_scale = self._dr_scales["friction_scale"].cpu().numpy()
|
||||||
|
dp_scale = self._dr_scales["damping_scale"].cpu().numpy()
|
||||||
|
tq_scale = self._dr_scales["torque_scale"].cpu().numpy()
|
||||||
|
|
||||||
for i, data in enumerate(self._data):
|
for i, data in enumerate(self._data):
|
||||||
data.ctrl[:] = actions_np[i]
|
# torque_scale emulates motor-constant / battery-voltage variation.
|
||||||
|
data.ctrl[:] = actions_np[i] * tq_scale[i]
|
||||||
for _ in range(self.config.substeps):
|
for _ in range(self.config.substeps):
|
||||||
# Apply asymmetric motor forces via qfrc_applied
|
# Apply asymmetric motor forces via qfrc_applied
|
||||||
for act, qvel_idx in self._motor_actuators:
|
for act, qvel_idx in self._motor_actuators:
|
||||||
vel = data.qvel[qvel_idx]
|
vel = data.qvel[qvel_idx]
|
||||||
ctrl = data.ctrl[0] # TODO: generalise for multi-actuator
|
ctrl = data.ctrl[0] # TODO: generalise for multi-actuator
|
||||||
data.qfrc_applied[qvel_idx] = act.compute_motor_force(vel, ctrl)
|
data.qfrc_applied[qvel_idx] = act.compute_motor_force(
|
||||||
|
vel, ctrl,
|
||||||
|
friction_scale=fr_scale[i],
|
||||||
|
damping_scale=dp_scale[i],
|
||||||
|
)
|
||||||
self._limits.enforce(self._model, data)
|
self._limits.enforce(self._model, data)
|
||||||
mujoco.mj_step(self._model, data)
|
mujoco.mj_step(self._model, data)
|
||||||
|
|
||||||
@@ -267,23 +277,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
|
|
||||||
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
ids = env_ids.cpu().numpy()
|
ids = env_ids.cpu().numpy()
|
||||||
n = len(ids)
|
|
||||||
|
|
||||||
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
# Env-defined initial-state distribution (shared with the MJX runner).
|
||||||
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
qpos_lo, qpos_hi, qvel_lo, qvel_hi = self.env.initial_state_ranges(
|
||||||
|
self._nq, self._nv,
|
||||||
|
)
|
||||||
|
|
||||||
for i, env_id in enumerate(ids):
|
for env_id in ids:
|
||||||
data = self._data[env_id]
|
data = self._data[env_id]
|
||||||
mujoco.mj_resetData(self._model, data)
|
mujoco.mj_resetData(self._model, data)
|
||||||
|
|
||||||
# Small random perturbation for exploration
|
data.qpos[:] += np.random.uniform(qpos_lo, qpos_hi)
|
||||||
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
data.qvel[:] += np.random.uniform(qvel_lo, qvel_hi)
|
||||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
|
||||||
data.ctrl[:] = 0.0
|
data.ctrl[:] = 0.0
|
||||||
|
|
||||||
qpos_batch[i] = data.qpos
|
# Full-batch return (see BaseRunner._sim_reset contract).
|
||||||
qvel_batch[i] = data.qvel
|
qpos_batch = np.stack([d.qpos for d in self._data]).astype(np.float32)
|
||||||
|
qvel_batch = np.stack([d.qvel for d in self._data]).astype(np.float32)
|
||||||
return (
|
return (
|
||||||
torch.from_numpy(qpos_batch).to(self.device),
|
torch.from_numpy(qpos_batch).to(self.device),
|
||||||
torch.from_numpy(qvel_batch).to(self.device),
|
torch.from_numpy(qvel_batch).to(self.device),
|
||||||
|
|||||||
@@ -243,8 +243,9 @@ def capture(
|
|||||||
idx = 0
|
idx = 0
|
||||||
pwm = 0
|
pwm = 0
|
||||||
last_esp_ms = -1 # firmware timestamp of last recorded sample
|
last_esp_ms = -1 # firmware timestamp of last recorded sample
|
||||||
|
esp_ms_origin: int | None = None # first firmware timestamp
|
||||||
no_data_count = 0 # consecutive timeouts with no data
|
no_data_count = 0 # consecutive timeouts with no data
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic() # host clock for duration check only
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# Block until the firmware sends the next state line (~20 ms).
|
# Block until the firmware sends the next state line (~20 ms).
|
||||||
@@ -276,7 +277,10 @@ def capture(
|
|||||||
continue
|
continue
|
||||||
last_esp_ms = esp_ms
|
last_esp_ms = esp_ms
|
||||||
|
|
||||||
elapsed = time.monotonic() - t0
|
# Use firmware clock for time axis (avoids host serial jitter).
|
||||||
|
if esp_ms_origin is None:
|
||||||
|
esp_ms_origin = esp_ms
|
||||||
|
elapsed = (esp_ms - esp_ms_origin) / 1000.0
|
||||||
if elapsed >= duration:
|
if elapsed >= duration:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -396,7 +400,7 @@ def main() -> None:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--amplitude", type=int, default=150,
|
"--amplitude", type=int, default=150,
|
||||||
help="Max PWM magnitude (should not exceed firmware MAX_MOTOR_SPEED=150)",
|
help="Max PWM magnitude for excitation (0-255)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
|
"--hold-min-ms", type=int, default=50, help="Min hold time (ms)"
|
||||||
|
|||||||
@@ -29,15 +29,18 @@ def export_tuned_files(
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
robot_path : robot asset directory (contains robot.yaml + *.urdf)
|
robot_path : robot asset directory (contains robot.yaml + *.urdf)
|
||||||
params : dict of parameter name → tuned value (from optimizer)
|
params : dict of parameter name → tuned value (the optimised set)
|
||||||
motor_params : locked motor sysid params (asymmetric model).
|
motor_params : locked motor parameters merged underneath ``params``
|
||||||
If provided, motor joint parameters come from here.
|
(``params`` wins on conflicts) so the exported YAML always has a
|
||||||
|
complete motor model
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(tuned_urdf_path, tuned_robot_yaml_path)
|
(tuned_urdf_path, tuned_robot_yaml_path)
|
||||||
"""
|
"""
|
||||||
robot_path = Path(robot_path).resolve()
|
robot_path = Path(robot_path).resolve()
|
||||||
|
if motor_params:
|
||||||
|
params = {**motor_params, **params}
|
||||||
|
|
||||||
# ── Load originals ───────────────────────────────────────────
|
# ── Load originals ───────────────────────────────────────────
|
||||||
robot_yaml_path = robot_path / "robot.yaml"
|
robot_yaml_path = robot_path / "robot.yaml"
|
||||||
@@ -66,39 +69,34 @@ def export_tuned_files(
|
|||||||
# Update actuator parameters — full asymmetric motor model.
|
# Update actuator parameters — full asymmetric motor model.
|
||||||
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
|
if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
|
||||||
act = tuned_cfg["actuators"][0]
|
act = tuned_cfg["actuators"][0]
|
||||||
if motor_params:
|
|
||||||
# Asymmetric gear, damping, deadzone, frictionloss as [pos, neg].
|
# Asymmetric gear, damping, deadzone, frictionloss as [pos, neg].
|
||||||
gear_pos = motor_params.get("actuator_gear_pos", 0.424)
|
act["gear"] = [
|
||||||
gear_neg = motor_params.get("actuator_gear_neg", 0.425)
|
round(params.get("actuator_gear_pos", 0.424), 6),
|
||||||
act["gear"] = [round(gear_pos, 6), round(gear_neg, 6)]
|
round(params.get("actuator_gear_neg", 0.425), 6),
|
||||||
|
]
|
||||||
damp_pos = motor_params.get("motor_damping_pos", 0.002)
|
act["damping"] = [
|
||||||
damp_neg = motor_params.get("motor_damping_neg", 0.015)
|
round(params.get("motor_damping_pos", 0.002), 6),
|
||||||
act["damping"] = [round(damp_pos, 6), round(damp_neg, 6)]
|
round(params.get("motor_damping_neg", 0.015), 6),
|
||||||
|
]
|
||||||
dz_pos = motor_params.get("motor_deadzone_pos", 0.141)
|
act["deadzone"] = [
|
||||||
dz_neg = motor_params.get("motor_deadzone_neg", 0.078)
|
round(params.get("motor_deadzone_pos", 0.141), 6),
|
||||||
act["deadzone"] = [round(dz_pos, 6), round(dz_neg, 6)]
|
round(params.get("motor_deadzone_neg", 0.078), 6),
|
||||||
|
]
|
||||||
fl_pos = motor_params.get("motor_frictionloss_pos", 0.057)
|
act["frictionloss"] = [
|
||||||
fl_neg = motor_params.get("motor_frictionloss_neg", 0.053)
|
round(params.get("motor_frictionloss_pos", 0.057), 6),
|
||||||
act["frictionloss"] = [round(fl_pos, 6), round(fl_neg, 6)]
|
round(params.get("motor_frictionloss_neg", 0.053), 6),
|
||||||
|
]
|
||||||
if "actuator_filter_tau" in motor_params:
|
|
||||||
act["filter_tau"] = round(motor_params["actuator_filter_tau"], 6)
|
|
||||||
if "viscous_quadratic" in motor_params:
|
|
||||||
act["viscous_quadratic"] = round(motor_params["viscous_quadratic"], 6)
|
|
||||||
if "back_emf_gain" in motor_params:
|
|
||||||
act["back_emf_gain"] = round(motor_params["back_emf_gain"], 6)
|
|
||||||
else:
|
|
||||||
if "actuator_gear" in params:
|
|
||||||
act["gear"] = round(params["actuator_gear"], 6)
|
|
||||||
if "actuator_filter_tau" in params:
|
if "actuator_filter_tau" in params:
|
||||||
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
|
act["filter_tau"] = round(params["actuator_filter_tau"], 6)
|
||||||
if "motor_damping" in params:
|
|
||||||
act["damping"] = round(params["motor_damping"], 6)
|
# Stribeck friction and action bias.
|
||||||
if "motor_deadzone" in params:
|
if "stribeck_friction_boost" in params:
|
||||||
act["deadzone"] = round(params["motor_deadzone"], 6)
|
act["stribeck_friction_boost"] = round(params["stribeck_friction_boost"], 6)
|
||||||
|
if "stribeck_vel" in params:
|
||||||
|
act["stribeck_vel"] = round(params["stribeck_vel"], 6)
|
||||||
|
if "action_bias" in params:
|
||||||
|
act["action_bias"] = round(params["action_bias"], 6)
|
||||||
|
|
||||||
# ctrl_range from ctrl_limit parameter.
|
# ctrl_range from ctrl_limit parameter.
|
||||||
if "ctrl_limit" in params:
|
if "ctrl_limit" in params:
|
||||||
@@ -112,15 +110,10 @@ def export_tuned_files(
|
|||||||
if "motor_joint" not in tuned_cfg["joints"]:
|
if "motor_joint" not in tuned_cfg["joints"]:
|
||||||
tuned_cfg["joints"]["motor_joint"] = {}
|
tuned_cfg["joints"]["motor_joint"] = {}
|
||||||
mj = tuned_cfg["joints"]["motor_joint"]
|
mj = tuned_cfg["joints"]["motor_joint"]
|
||||||
if motor_params:
|
|
||||||
mj["armature"] = round(motor_params.get("motor_armature", 0.00277), 6)
|
|
||||||
# Frictionloss/damping = 0 in MuJoCo (motor model handles via qfrc_applied).
|
|
||||||
mj["frictionloss"] = 0.0
|
|
||||||
else:
|
|
||||||
if "motor_armature" in params:
|
if "motor_armature" in params:
|
||||||
mj["armature"] = round(params["motor_armature"], 6)
|
mj["armature"] = round(params["motor_armature"], 6)
|
||||||
if "motor_frictionloss" in params:
|
# Frictionloss/damping = 0 in MuJoCo (motor model handles via qfrc_applied).
|
||||||
mj["frictionloss"] = round(params["motor_frictionloss"], 6)
|
mj["frictionloss"] = 0.0
|
||||||
|
|
||||||
if "pendulum_joint" not in tuned_cfg["joints"]:
|
if "pendulum_joint" not in tuned_cfg["joints"]:
|
||||||
tuned_cfg["joints"]["pendulum_joint"] = {}
|
tuned_cfg["joints"]["pendulum_joint"] = {}
|
||||||
@@ -154,8 +147,6 @@ def main() -> None:
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.sysid.rollout import LOCKED_MOTOR_PARAMS
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Export tuned URDF + robot.yaml from sysid results."
|
description="Export tuned URDF + robot.yaml from sysid results."
|
||||||
)
|
)
|
||||||
@@ -183,7 +174,6 @@ def main() -> None:
|
|||||||
export_tuned_files(
|
export_tuned_files(
|
||||||
robot_path=args.robot_path,
|
robot_path=args.robot_path,
|
||||||
params=result["best_params"],
|
params=result["best_params"],
|
||||||
motor_params=result.get("motor_params", dict(LOCKED_MOTOR_PARAMS)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Motor-only system identification subpackage.
|
|
||||||
|
|
||||||
Identifies JGB37-520 DC motor dynamics (no pendulum, no limits)
|
|
||||||
so the MuJoCo simulation matches the real hardware response.
|
|
||||||
"""
|
|
||||||
@@ -1,364 +0,0 @@
|
|||||||
"""Capture a motor-only trajectory under random excitation (PRBS-style).
|
|
||||||
|
|
||||||
Connects to the ESP32 running the simplified sysid firmware (no pendulum,
|
|
||||||
no limits), sends random PWM commands, and records motor angle + velocity
|
|
||||||
at ~ 50 Hz.
|
|
||||||
|
|
||||||
Firmware serial protocol (115200 baud):
|
|
||||||
Commands: M<speed>\\n R\\n S\\n G\\n H\\n P\\n
|
|
||||||
State: S,<millis>,<encoder_count>,<rpm>,<applied_speed>,<enc_vel_cps>\\n
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m src.sysid.motor.capture --duration 20
|
|
||||||
python -m src.sysid.motor.capture --duration 30 --amplitude 200
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import structlog
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
log = structlog.get_logger()
|
|
||||||
|
|
||||||
# ── Default asset path ───────────────────────────────────────────────
|
|
||||||
_DEFAULT_ASSET = "assets/motor"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Serial protocol helpers ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_state_line(line: str) -> dict[str, Any] | None:
|
|
||||||
"""Parse an ``S,…`` state line from the motor sysid firmware.
|
|
||||||
|
|
||||||
Format: S,<millis>,<encoder_count>,<rpm>,<applied_speed>,<enc_vel_cps>
|
|
||||||
"""
|
|
||||||
if not line.startswith("S,"):
|
|
||||||
return None
|
|
||||||
parts = line.split(",")
|
|
||||||
if len(parts) < 6:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return {
|
|
||||||
"timestamp_ms": int(parts[1]),
|
|
||||||
"encoder_count": int(parts[2]),
|
|
||||||
"rpm": float(parts[3]),
|
|
||||||
"applied_speed": int(parts[4]),
|
|
||||||
"enc_vel_cps": float(parts[5]),
|
|
||||||
}
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Background serial reader ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _SerialReader:
|
|
||||||
"""Minimal background reader for the ESP32 serial stream."""
|
|
||||||
|
|
||||||
def __init__(self, port: str, baud: int = 115200):
|
|
||||||
import serial as _serial
|
|
||||||
|
|
||||||
self._serial_mod = _serial
|
|
||||||
self.ser = _serial.Serial(port, baud, timeout=0.05)
|
|
||||||
time.sleep(2) # Wait for ESP32 boot
|
|
||||||
self.ser.reset_input_buffer()
|
|
||||||
|
|
||||||
self._latest: dict[str, Any] = {}
|
|
||||||
self._seq: int = 0
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._cond = threading.Condition(self._lock)
|
|
||||||
self._running = True
|
|
||||||
|
|
||||||
self._thread = threading.Thread(target=self._reader_loop, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
|
|
||||||
def _reader_loop(self) -> None:
|
|
||||||
while self._running:
|
|
||||||
try:
|
|
||||||
if self.ser.in_waiting:
|
|
||||||
line = (
|
|
||||||
self.ser.readline()
|
|
||||||
.decode("utf-8", errors="ignore")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
parsed = _parse_state_line(line)
|
|
||||||
if parsed is not None:
|
|
||||||
with self._cond:
|
|
||||||
self._latest = parsed
|
|
||||||
self._seq += 1
|
|
||||||
self._cond.notify_all()
|
|
||||||
elif line and not line.startswith("S,"):
|
|
||||||
# Log non-state lines (READY, PONG, WARN, etc.)
|
|
||||||
log.debug("serial_info", line=line)
|
|
||||||
else:
|
|
||||||
time.sleep(0.001)
|
|
||||||
except (OSError, self._serial_mod.SerialException):
|
|
||||||
log.critical("serial_lost")
|
|
||||||
break
|
|
||||||
|
|
||||||
def send(self, cmd: str) -> None:
|
|
||||||
try:
|
|
||||||
self.ser.write(f"{cmd}\n".encode())
|
|
||||||
except (OSError, self._serial_mod.SerialException):
|
|
||||||
log.critical("serial_send_failed", cmd=cmd)
|
|
||||||
|
|
||||||
def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
|
|
||||||
"""Wait until a *new* state line arrives, then return it."""
|
|
||||||
with self._cond:
|
|
||||||
seq_before = self._seq
|
|
||||||
if not self._cond.wait_for(
|
|
||||||
lambda: self._seq > seq_before, timeout=timeout
|
|
||||||
):
|
|
||||||
return {}
|
|
||||||
return dict(self._latest)
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self._running = False
|
|
||||||
self.send("H")
|
|
||||||
self.send("M0")
|
|
||||||
time.sleep(0.1)
|
|
||||||
self._thread.join(timeout=1.0)
|
|
||||||
self.ser.close()
|
|
||||||
|
|
||||||
|
|
||||||
# ── PRBS excitation signal ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _PRBSExcitation:
|
|
||||||
"""Random hold-value excitation with configurable amplitude and hold time."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
amplitude: int = 200,
|
|
||||||
hold_min_ms: int = 50,
|
|
||||||
hold_max_ms: int = 400,
|
|
||||||
):
|
|
||||||
self.amplitude = amplitude
|
|
||||||
self.hold_min_ms = hold_min_ms
|
|
||||||
self.hold_max_ms = hold_max_ms
|
|
||||||
self._current: int = 0
|
|
||||||
self._switch_time: float = 0.0
|
|
||||||
self._new_value()
|
|
||||||
|
|
||||||
def _new_value(self) -> None:
|
|
||||||
self._current = random.randint(-self.amplitude, self.amplitude)
|
|
||||||
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
|
|
||||||
self._switch_time = time.monotonic() + hold_ms / 1000.0
|
|
||||||
|
|
||||||
def __call__(self) -> int:
|
|
||||||
if time.monotonic() >= self._switch_time:
|
|
||||||
self._new_value()
|
|
||||||
return self._current
|
|
||||||
|
|
||||||
|
|
||||||
# ── Main capture loop ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def capture(
|
|
||||||
asset_path: str | Path = _DEFAULT_ASSET,
|
|
||||||
port: str = "/dev/cu.usbserial-0001",
|
|
||||||
baud: int = 115200,
|
|
||||||
duration: float = 20.0,
|
|
||||||
amplitude: int = 200,
|
|
||||||
hold_min_ms: int = 50,
|
|
||||||
hold_max_ms: int = 400,
|
|
||||||
dt: float = 0.02,
|
|
||||||
) -> Path:
|
|
||||||
"""Run motor-only capture and return the path to the saved .npz file.
|
|
||||||
|
|
||||||
Stream-driven: blocks on each firmware state line (~50 Hz),
|
|
||||||
sends next motor command immediately, records both.
|
|
||||||
No time.sleep pacing — locked to firmware clock.
|
|
||||||
|
|
||||||
The recording stores:
|
|
||||||
- time: wall-clock seconds since start
|
|
||||||
- action: normalised action = applied_speed / 255
|
|
||||||
- motor_angle: shaft angle in radians (from encoder)
|
|
||||||
- motor_vel: shaft velocity in rad/s (from encoder velocity)
|
|
||||||
"""
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
|
|
||||||
# Load hardware config for encoder conversion.
|
|
||||||
hw_yaml = asset_path / "hardware.yaml"
|
|
||||||
if not hw_yaml.exists():
|
|
||||||
raise FileNotFoundError(f"hardware.yaml not found in {asset_path}")
|
|
||||||
raw_hw = yaml.safe_load(hw_yaml.read_text())
|
|
||||||
ppr = raw_hw.get("encoder", {}).get("ppr", 11)
|
|
||||||
gear_ratio = raw_hw.get("encoder", {}).get("gear_ratio", 30.0)
|
|
||||||
counts_per_rev: float = ppr * gear_ratio * 4.0
|
|
||||||
max_pwm = raw_hw.get("motor", {}).get("max_pwm", 255)
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"hardware_config",
|
|
||||||
ppr=ppr,
|
|
||||||
gear_ratio=gear_ratio,
|
|
||||||
counts_per_rev=counts_per_rev,
|
|
||||||
max_pwm=max_pwm,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Connect.
|
|
||||||
reader = _SerialReader(port, baud)
|
|
||||||
excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)
|
|
||||||
|
|
||||||
# Prepare recording buffers.
|
|
||||||
max_samples = int(duration / dt) + 500
|
|
||||||
rec_time = np.zeros(max_samples, dtype=np.float64)
|
|
||||||
rec_action = np.zeros(max_samples, dtype=np.float64)
|
|
||||||
rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
|
|
||||||
rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
|
|
||||||
|
|
||||||
# Reset encoder to zero.
|
|
||||||
reader.send("R")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
# Start streaming.
|
|
||||||
reader.send("G")
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"capture_starting",
|
|
||||||
port=port,
|
|
||||||
duration=duration,
|
|
||||||
amplitude=amplitude,
|
|
||||||
hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
|
|
||||||
)
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
pwm = 0
|
|
||||||
last_esp_ms = -1
|
|
||||||
t0 = time.monotonic()
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
state = reader.read_blocking(timeout=0.1)
|
|
||||||
if not state:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Deduplicate by firmware timestamp.
|
|
||||||
esp_ms = state.get("timestamp_ms", 0)
|
|
||||||
if esp_ms == last_esp_ms:
|
|
||||||
continue
|
|
||||||
last_esp_ms = esp_ms
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - t0
|
|
||||||
if elapsed >= duration:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Generate next excitation PWM.
|
|
||||||
pwm = excitation()
|
|
||||||
|
|
||||||
# Send command.
|
|
||||||
reader.send(f"M{pwm}")
|
|
||||||
|
|
||||||
# Convert encoder to angle/velocity.
|
|
||||||
enc = state.get("encoder_count", 0)
|
|
||||||
motor_angle = enc / counts_per_rev * 2.0 * math.pi
|
|
||||||
motor_vel = (
|
|
||||||
state.get("enc_vel_cps", 0.0) / counts_per_rev * 2.0 * math.pi
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use firmware-applied speed for the action.
|
|
||||||
applied = state.get("applied_speed", 0)
|
|
||||||
action_norm = applied / 255.0
|
|
||||||
|
|
||||||
if idx < max_samples:
|
|
||||||
rec_time[idx] = elapsed
|
|
||||||
rec_action[idx] = action_norm
|
|
||||||
rec_motor_angle[idx] = motor_angle
|
|
||||||
rec_motor_vel[idx] = motor_vel
|
|
||||||
idx += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
if idx % 50 == 0:
|
|
||||||
log.info(
|
|
||||||
"capture_progress",
|
|
||||||
elapsed=f"{elapsed:.1f}/{duration:.0f}s",
|
|
||||||
samples=idx,
|
|
||||||
pwm=pwm,
|
|
||||||
angle_deg=f"{math.degrees(motor_angle):.1f}",
|
|
||||||
vel_rps=f"{motor_vel / (2 * math.pi):.1f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
reader.send("M0")
|
|
||||||
reader.close()
|
|
||||||
|
|
||||||
# Trim.
|
|
||||||
rec_time = rec_time[:idx]
|
|
||||||
rec_action = rec_action[:idx]
|
|
||||||
rec_motor_angle = rec_motor_angle[:idx]
|
|
||||||
rec_motor_vel = rec_motor_vel[:idx]
|
|
||||||
|
|
||||||
# Save.
|
|
||||||
recordings_dir = asset_path / "recordings"
|
|
||||||
recordings_dir.mkdir(exist_ok=True)
|
|
||||||
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
out_path = recordings_dir / f"motor_capture_{stamp}.npz"
|
|
||||||
np.savez_compressed(
|
|
||||||
out_path,
|
|
||||||
time=rec_time,
|
|
||||||
action=rec_action,
|
|
||||||
motor_angle=rec_motor_angle,
|
|
||||||
motor_vel=rec_motor_vel,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"capture_saved",
|
|
||||||
path=str(out_path),
|
|
||||||
samples=idx,
|
|
||||||
duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
|
|
||||||
)
|
|
||||||
return out_path
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Capture motor-only trajectory for system identification."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--asset-path", type=str, default=_DEFAULT_ASSET,
|
|
||||||
help="Path to motor asset directory (contains hardware.yaml)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port", type=str, default="/dev/cu.usbserial-0001",
|
|
||||||
help="Serial port for ESP32",
|
|
||||||
)
|
|
||||||
parser.add_argument("--baud", type=int, default=115200)
|
|
||||||
parser.add_argument("--duration", type=float, default=20.0, help="Capture duration (s)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--amplitude", type=int, default=200,
|
|
||||||
help="Max PWM magnitude (0–255, firmware allows full range)",
|
|
||||||
)
|
|
||||||
parser.add_argument("--hold-min-ms", type=int, default=50, help="PRBS min hold (ms)")
|
|
||||||
parser.add_argument("--hold-max-ms", type=int, default=400, help="PRBS max hold (ms)")
|
|
||||||
parser.add_argument("--dt", type=float, default=0.02, help="Nominal sample period (s)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
capture(
|
|
||||||
asset_path=args.asset_path,
|
|
||||||
port=args.port,
|
|
||||||
baud=args.baud,
|
|
||||||
duration=args.duration,
|
|
||||||
amplitude=args.amplitude,
|
|
||||||
hold_min_ms=args.hold_min_ms,
|
|
||||||
hold_max_ms=args.hold_max_ms,
|
|
||||||
dt=args.dt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
"""CMA-ES optimiser — fit motor simulation parameters to a real recording.
|
|
||||||
|
|
||||||
Motor-only version: minimises trajectory-matching cost between MuJoCo
|
|
||||||
rollout and recorded motor angle + velocity.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m src.sysid.motor.optimize \
|
|
||||||
--recording assets/motor/recordings/motor_capture_YYYYMMDD_HHMMSS.npz
|
|
||||||
|
|
||||||
# Quick test:
|
|
||||||
python -m src.sysid.motor.optimize --recording <file>.npz \
|
|
||||||
--max-generations 20 --population-size 10
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
from src.sysid.motor.preprocess import recompute_velocity
|
|
||||||
from src.sysid.motor.rollout import (
|
|
||||||
MOTOR_PARAMS,
|
|
||||||
ParamSpec,
|
|
||||||
bounds_arrays,
|
|
||||||
defaults_vector,
|
|
||||||
params_to_dict,
|
|
||||||
rollout,
|
|
||||||
windowed_rollout,
|
|
||||||
)
|
|
||||||
|
|
||||||
log = structlog.get_logger()
|
|
||||||
|
|
||||||
_DEFAULT_ASSET = "assets/motor"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cost function ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_trajectory_cost(
|
|
||||||
sim: dict[str, np.ndarray],
|
|
||||||
recording: dict[str, np.ndarray],
|
|
||||||
pos_weight: float = 1.0,
|
|
||||||
vel_weight: float = 0.5,
|
|
||||||
acc_weight: float = 0.0,
|
|
||||||
dt: float = 0.02,
|
|
||||||
) -> float:
|
|
||||||
"""Weighted MSE between simulated and real motor trajectories.
|
|
||||||
|
|
||||||
Motor-only: angle, velocity, and optionally acceleration.
|
|
||||||
Acceleration is computed as finite-difference of velocity.
|
|
||||||
"""
|
|
||||||
angle_err = sim["motor_angle"] - recording["motor_angle"]
|
|
||||||
vel_err = sim["motor_vel"] - recording["motor_vel"]
|
|
||||||
|
|
||||||
# Reject NaN results (unstable simulation).
|
|
||||||
if np.any(~np.isfinite(angle_err)) or np.any(~np.isfinite(vel_err)):
|
|
||||||
return 1e6
|
|
||||||
|
|
||||||
cost = float(
|
|
||||||
pos_weight * np.mean(angle_err**2)
|
|
||||||
+ vel_weight * np.mean(vel_err**2)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Optional acceleration term — penalises wrong dynamics (d(vel)/dt).
|
|
||||||
if acc_weight > 0 and len(vel_err) > 2:
|
|
||||||
sim_acc = np.diff(sim["motor_vel"]) / dt
|
|
||||||
real_acc = np.diff(recording["motor_vel"]) / dt
|
|
||||||
acc_err = sim_acc - real_acc
|
|
||||||
if np.any(~np.isfinite(acc_err)):
|
|
||||||
return 1e6
|
|
||||||
# Normalise by typical acceleration scale (~50 rad/s²) to keep
|
|
||||||
# the weight intuitive relative to vel/pos terms.
|
|
||||||
cost += acc_weight * np.mean(acc_err**2) / (50.0**2)
|
|
||||||
|
|
||||||
return cost
|
|
||||||
|
|
||||||
|
|
||||||
def cost_function(
|
|
||||||
params_vec: np.ndarray,
|
|
||||||
recording: dict[str, np.ndarray],
|
|
||||||
asset_path: Path,
|
|
||||||
specs: list[ParamSpec],
|
|
||||||
sim_dt: float = 0.002,
|
|
||||||
substeps: int = 10,
|
|
||||||
pos_weight: float = 1.0,
|
|
||||||
vel_weight: float = 0.5,
|
|
||||||
acc_weight: float = 0.0,
|
|
||||||
window_duration: float = 0.5,
|
|
||||||
) -> float:
|
|
||||||
"""Compute trajectory-matching cost for a candidate parameter vector.
|
|
||||||
|
|
||||||
Uses multiple-shooting (windowed rollout) by default.
|
|
||||||
"""
|
|
||||||
params = params_to_dict(params_vec, specs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if window_duration > 0:
|
|
||||||
sim = windowed_rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=params,
|
|
||||||
recording=recording,
|
|
||||||
window_duration=window_duration,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sim = rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=params,
|
|
||||||
actions=recording["action"],
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
log.warning("rollout_failed", error=str(exc))
|
|
||||||
return 1e6
|
|
||||||
|
|
||||||
return _compute_trajectory_cost(
|
|
||||||
sim, recording, pos_weight, vel_weight, acc_weight,
|
|
||||||
dt=np.mean(np.diff(recording["time"])) if len(recording["time"]) > 1 else 0.02,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── CMA-ES optimisation loop ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def optimize(
|
|
||||||
asset_path: str | Path = _DEFAULT_ASSET,
|
|
||||||
recording_path: str | Path = "",
|
|
||||||
specs: list[ParamSpec] | None = None,
|
|
||||||
sigma0: float = 0.3,
|
|
||||||
population_size: int = 30,
|
|
||||||
max_generations: int = 300,
|
|
||||||
sim_dt: float = 0.002,
|
|
||||||
substeps: int = 10,
|
|
||||||
pos_weight: float = 1.0,
|
|
||||||
vel_weight: float = 0.5,
|
|
||||||
acc_weight: float = 0.0,
|
|
||||||
window_duration: float = 0.5,
|
|
||||||
seed: int = 42,
|
|
||||||
preprocess_vel: bool = True,
|
|
||||||
sg_window: int = 7,
|
|
||||||
sg_polyorder: int = 3,
|
|
||||||
) -> dict:
|
|
||||||
"""Run CMA-ES optimisation and return results dict."""
|
|
||||||
from cmaes import CMA
|
|
||||||
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
recording_path = Path(recording_path).resolve()
|
|
||||||
|
|
||||||
if specs is None:
|
|
||||||
specs = MOTOR_PARAMS
|
|
||||||
|
|
||||||
# Load recording.
|
|
||||||
recording = dict(np.load(recording_path))
|
|
||||||
n_samples = len(recording["time"])
|
|
||||||
duration = recording["time"][-1] - recording["time"][0]
|
|
||||||
n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
|
|
||||||
log.info(
|
|
||||||
"recording_loaded",
|
|
||||||
path=str(recording_path),
|
|
||||||
samples=n_samples,
|
|
||||||
duration=f"{duration:.1f}s",
|
|
||||||
n_windows=n_windows,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Preprocess velocity: replace noisy firmware finite-difference with
|
|
||||||
# smooth Savitzky-Golay derivative of the angle signal.
|
|
||||||
if preprocess_vel:
|
|
||||||
recording = recompute_velocity(
|
|
||||||
recording,
|
|
||||||
window_length=sg_window,
|
|
||||||
polyorder=sg_polyorder,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalise to [0, 1] for CMA-ES.
|
|
||||||
lo, hi = bounds_arrays(specs)
|
|
||||||
x0 = defaults_vector(specs)
|
|
||||||
span = hi - lo
|
|
||||||
span[span == 0] = 1.0
|
|
||||||
|
|
||||||
def to_normed(x: np.ndarray) -> np.ndarray:
|
|
||||||
return (x - lo) / span
|
|
||||||
|
|
||||||
def from_normed(x_n: np.ndarray) -> np.ndarray:
|
|
||||||
return x_n * span + lo
|
|
||||||
|
|
||||||
x0_normed = to_normed(x0)
|
|
||||||
bounds_normed = np.column_stack(
|
|
||||||
[np.zeros(len(specs)), np.ones(len(specs))]
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer = CMA(
|
|
||||||
mean=x0_normed,
|
|
||||||
sigma=sigma0,
|
|
||||||
bounds=bounds_normed,
|
|
||||||
population_size=population_size,
|
|
||||||
seed=seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
best_cost = float("inf")
|
|
||||||
best_params_vec = x0.copy()
|
|
||||||
history: list[tuple[int, float]] = []
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"cmaes_starting",
|
|
||||||
n_params=len(specs),
|
|
||||||
population=population_size,
|
|
||||||
max_gens=max_generations,
|
|
||||||
sigma0=sigma0,
|
|
||||||
)
|
|
||||||
|
|
||||||
t0 = time.monotonic()
|
|
||||||
|
|
||||||
for gen in range(max_generations):
|
|
||||||
solutions = []
|
|
||||||
for _ in range(optimizer.population_size):
|
|
||||||
x_normed = optimizer.ask()
|
|
||||||
x_natural = from_normed(x_normed)
|
|
||||||
x_natural = np.clip(x_natural, lo, hi)
|
|
||||||
|
|
||||||
c = cost_function(
|
|
||||||
x_natural,
|
|
||||||
recording,
|
|
||||||
asset_path,
|
|
||||||
specs,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
pos_weight=pos_weight,
|
|
||||||
vel_weight=vel_weight,
|
|
||||||
acc_weight=acc_weight,
|
|
||||||
window_duration=window_duration,
|
|
||||||
)
|
|
||||||
solutions.append((x_normed, c))
|
|
||||||
|
|
||||||
if c < best_cost:
|
|
||||||
best_cost = c
|
|
||||||
best_params_vec = x_natural.copy()
|
|
||||||
|
|
||||||
optimizer.tell(solutions)
|
|
||||||
history.append((gen, best_cost))
|
|
||||||
|
|
||||||
elapsed = time.monotonic() - t0
|
|
||||||
if gen % 5 == 0 or gen == max_generations - 1:
|
|
||||||
log.info(
|
|
||||||
"cmaes_generation",
|
|
||||||
gen=gen,
|
|
||||||
best_cost=f"{best_cost:.6f}",
|
|
||||||
elapsed=f"{elapsed:.1f}s",
|
|
||||||
gen_best=f"{min(c for _, c in solutions):.6f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
total_time = time.monotonic() - t0
|
|
||||||
best_params = params_to_dict(best_params_vec, specs)
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"cmaes_finished",
|
|
||||||
best_cost=f"{best_cost:.6f}",
|
|
||||||
total_time=f"{total_time:.1f}s",
|
|
||||||
evaluations=max_generations * population_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log parameter comparison.
|
|
||||||
defaults = params_to_dict(defaults_vector(specs), specs)
|
|
||||||
for name in best_params:
|
|
||||||
d = defaults[name]
|
|
||||||
b = best_params[name]
|
|
||||||
change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
|
|
||||||
log.info(
|
|
||||||
"param_result",
|
|
||||||
name=name,
|
|
||||||
default=f"{d:.6g}",
|
|
||||||
tuned=f"{b:.6g}",
|
|
||||||
change=f"{change_pct:+.1f}%",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"best_params": best_params,
|
|
||||||
"best_cost": best_cost,
|
|
||||||
"history": history,
|
|
||||||
"recording": str(recording_path),
|
|
||||||
"param_names": [s.name for s in specs],
|
|
||||||
"defaults": {s.name: s.default for s in specs},
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Fit motor simulation parameters to a real recording (CMA-ES)."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--asset-path", type=str, default=_DEFAULT_ASSET,
|
|
||||||
help="Path to motor asset directory",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--recording", type=str, required=True,
|
|
||||||
help="Path to .npz recording file",
|
|
||||||
)
|
|
||||||
parser.add_argument("--sigma0", type=float, default=0.3)
|
|
||||||
parser.add_argument("--population-size", type=int, default=30)
|
|
||||||
parser.add_argument("--max-generations", type=int, default=300)
|
|
||||||
parser.add_argument("--sim-dt", type=float, default=0.002)
|
|
||||||
parser.add_argument("--substeps", type=int, default=10)
|
|
||||||
parser.add_argument("--pos-weight", type=float, default=1.0)
|
|
||||||
parser.add_argument("--vel-weight", type=float, default=0.5)
|
|
||||||
parser.add_argument("--acc-weight", type=float, default=0.0,
|
|
||||||
help="Weight for acceleration matching (0=off, try 0.1-0.5)")
|
|
||||||
parser.add_argument("--window-duration", type=float, default=0.5)
|
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-preprocess-vel", action="store_true",
|
|
||||||
help="Disable Savitzky-Golay velocity preprocessing",
|
|
||||||
)
|
|
||||||
parser.add_argument("--sg-window", type=int, default=7,
|
|
||||||
help="Savitzky-Golay window length (odd, default 7 = 140ms)")
|
|
||||||
parser.add_argument("--sg-polyorder", type=int, default=3,
|
|
||||||
help="Savitzky-Golay polynomial order (default 3 = cubic)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
result = optimize(
|
|
||||||
asset_path=args.asset_path,
|
|
||||||
recording_path=args.recording,
|
|
||||||
sigma0=args.sigma0,
|
|
||||||
population_size=args.population_size,
|
|
||||||
max_generations=args.max_generations,
|
|
||||||
sim_dt=args.sim_dt,
|
|
||||||
substeps=args.substeps,
|
|
||||||
pos_weight=args.pos_weight,
|
|
||||||
vel_weight=args.vel_weight,
|
|
||||||
acc_weight=args.acc_weight,
|
|
||||||
window_duration=args.window_duration,
|
|
||||||
seed=args.seed,
|
|
||||||
preprocess_vel=not args.no_preprocess_vel,
|
|
||||||
sg_window=args.sg_window,
|
|
||||||
sg_polyorder=args.sg_polyorder,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save results JSON.
|
|
||||||
asset_path = Path(args.asset_path).resolve()
|
|
||||||
result_path = asset_path / "motor_sysid_result.json"
|
|
||||||
result_json = {k: v for k, v in result.items() if k != "history"}
|
|
||||||
result_json["history_summary"] = {
|
|
||||||
"first_cost": result["history"][0][1] if result["history"] else None,
|
|
||||||
"final_cost": result["history"][-1][1] if result["history"] else None,
|
|
||||||
"generations": len(result["history"]),
|
|
||||||
}
|
|
||||||
result_path.write_text(json.dumps(result_json, indent=2, default=str))
|
|
||||||
log.info("results_saved", path=str(result_path))
|
|
||||||
|
|
||||||
# Export tuned files.
|
|
||||||
from src.sysid.motor.export import export_tuned_files
|
|
||||||
|
|
||||||
export_tuned_files(asset_path=args.asset_path, params=result["best_params"])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
"""Recording preprocessing — clean velocity estimation from angle data.
|
|
||||||
|
|
||||||
The ESP32 firmware computes velocity as a raw finite-difference of encoder
|
|
||||||
counts at 50 Hz. With a 1320 CPR encoder that gives ~0.24 rad/s of
|
|
||||||
quantisation noise per count. This module replaces the noisy firmware
|
|
||||||
velocity with a smooth differentiation of the (much cleaner) angle signal.
|
|
||||||
|
|
||||||
Method: Savitzky-Golay filter applied to the angle signal, then
|
|
||||||
differentiated analytically. Zero phase lag, preserves transients well.
|
|
||||||
|
|
||||||
This is standard practice in robotics sysid — see e.g. MATLAB System ID
|
|
||||||
Toolbox, Drake's trajectory processing, or ETH's ANYmal sysid pipeline.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from scipy.signal import savgol_filter
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
log = structlog.get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def recompute_velocity(
|
|
||||||
recording: dict[str, np.ndarray],
|
|
||||||
window_length: int = 7,
|
|
||||||
polyorder: int = 3,
|
|
||||||
deriv: int = 1,
|
|
||||||
) -> dict[str, np.ndarray]:
|
|
||||||
"""Recompute motor_vel from motor_angle using Savitzky-Golay differentiation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
recording : dict with at least 'time', 'motor_angle', 'motor_vel' keys.
|
|
||||||
window_length : SG filter window (must be odd, > polyorder).
|
|
||||||
7 samples at 50 Hz = 140ms window — good balance of smoothness
|
|
||||||
and responsiveness. Captures dynamics up to ~7 Hz.
|
|
||||||
polyorder : Polynomial order for the SG filter (3 = cubic).
|
|
||||||
deriv : Derivative order (1 = first derivative = velocity).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
New recording dict with 'motor_vel' replaced and 'motor_vel_raw' added.
|
|
||||||
"""
|
|
||||||
rec = dict(recording) # shallow copy
|
|
||||||
|
|
||||||
times = rec["time"]
|
|
||||||
angles = rec["motor_angle"]
|
|
||||||
dt = np.mean(np.diff(times))
|
|
||||||
|
|
||||||
# Keep original for diagnostics.
|
|
||||||
rec["motor_vel_raw"] = rec["motor_vel"].copy()
|
|
||||||
|
|
||||||
# Savitzky-Golay derivative: fits a polynomial to each window,
|
|
||||||
# then takes the analytical derivative → smooth, zero phase lag.
|
|
||||||
vel_sg = savgol_filter(
|
|
||||||
angles,
|
|
||||||
window_length=window_length,
|
|
||||||
polyorder=polyorder,
|
|
||||||
deriv=deriv,
|
|
||||||
delta=dt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute stats for logging.
|
|
||||||
raw_vel = rec["motor_vel_raw"]
|
|
||||||
noise_estimate = np.std(raw_vel - vel_sg)
|
|
||||||
max_diff = np.max(np.abs(raw_vel - vel_sg))
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"velocity_recomputed",
|
|
||||||
method="savgol",
|
|
||||||
window=window_length,
|
|
||||||
polyorder=polyorder,
|
|
||||||
dt=f"{dt*1000:.1f}ms",
|
|
||||||
noise_std=f"{noise_estimate:.3f} rad/s",
|
|
||||||
max_diff=f"{max_diff:.3f} rad/s",
|
|
||||||
raw_rms=f"{np.sqrt(np.mean(raw_vel**2)):.3f}",
|
|
||||||
sg_rms=f"{np.sqrt(np.mean(vel_sg**2)):.3f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
rec["motor_vel"] = vel_sg
|
|
||||||
return rec
|
|
||||||
|
|
||||||
|
|
||||||
def smooth_velocity(
|
|
||||||
recording: dict[str, np.ndarray],
|
|
||||||
cutoff_hz: float = 10.0,
|
|
||||||
) -> dict[str, np.ndarray]:
|
|
||||||
"""Alternative: apply zero-phase Butterworth low-pass to motor_vel.
|
|
||||||
|
|
||||||
Simpler than SG derivative but introduces slight edge effects.
|
|
||||||
"""
|
|
||||||
from scipy.signal import butter, filtfilt
|
|
||||||
|
|
||||||
rec = dict(recording)
|
|
||||||
rec["motor_vel_raw"] = rec["motor_vel"].copy()
|
|
||||||
|
|
||||||
dt = np.mean(np.diff(rec["time"]))
|
|
||||||
fs = 1.0 / dt
|
|
||||||
nyq = fs / 2.0
|
|
||||||
norm_cutoff = min(cutoff_hz / nyq, 0.99)
|
|
||||||
|
|
||||||
b, a = butter(2, norm_cutoff, btype="low")
|
|
||||||
rec["motor_vel"] = filtfilt(b, a, rec["motor_vel"])
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"velocity_smoothed",
|
|
||||||
method="butterworth",
|
|
||||||
cutoff_hz=cutoff_hz,
|
|
||||||
fs=fs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return rec
|
|
||||||
@@ -1,460 +0,0 @@
|
|||||||
"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
|
|
||||||
|
|
||||||
Motor-only version: single hinge joint, no pendulum.
|
|
||||||
Given a parameter vector and recorded actions, builds a MuJoCo model
|
|
||||||
with overridden dynamics, replays the actions, and returns the simulated
|
|
||||||
motor angle + velocity for comparison with the real recording.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import mujoco
|
|
||||||
import numpy as np
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tunable parameter specification ──────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ParamSpec:
|
|
||||||
"""Specification for a single tunable parameter."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
default: float
|
|
||||||
lower: float
|
|
||||||
upper: float
|
|
||||||
log_scale: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Motor-only parameters to identify.
|
|
||||||
# These capture the full transfer function: PWM → shaft angle/velocity.
|
|
||||||
#
|
|
||||||
# Asymmetric parameters (pos/neg suffix) capture real-world differences
|
|
||||||
# between CW and CCW rotation caused by gear mesh, brush contact,
|
|
||||||
# and H-bridge asymmetry.
|
|
||||||
MOTOR_PARAMS: list[ParamSpec] = [
|
|
||||||
# ── Actuator ─────────────────────────────────────────────────
|
|
||||||
# gear_pos/neg: effective torque per unit ctrl, split by direction.
|
|
||||||
# Real motors + L298N often have different drive strength per direction.
|
|
||||||
ParamSpec("actuator_gear_pos", 0.064, 0.005, 0.5, log_scale=True),
|
|
||||||
ParamSpec("actuator_gear_neg", 0.064, 0.005, 0.5, log_scale=True),
|
|
||||||
# filter_tau: first-order electrical/driver time constant (s).
|
|
||||||
# Lower bound 1ms — L298N PWM switching is very fast.
|
|
||||||
ParamSpec("actuator_filter_tau", 0.03, 0.001, 0.20),
|
|
||||||
# ── Joint dynamics ───────────────────────────────────────────
|
|
||||||
# damping_pos/neg: viscous friction (back-EMF), split by direction.
|
|
||||||
ParamSpec("motor_damping_pos", 0.003, 1e-5, 0.1, log_scale=True),
|
|
||||||
ParamSpec("motor_damping_neg", 0.003, 1e-5, 0.1, log_scale=True),
|
|
||||||
# armature: reflected rotor inertia (kg·m²).
|
|
||||||
ParamSpec("motor_armature", 0.0001, 1e-6, 0.01, log_scale=True),
|
|
||||||
# frictionloss_pos/neg: Coulomb friction, split by velocity direction.
|
|
||||||
ParamSpec("motor_frictionloss_pos", 0.03, 0.001, 0.2, log_scale=True),
|
|
||||||
ParamSpec("motor_frictionloss_neg", 0.03, 0.001, 0.2, log_scale=True),
|
|
||||||
# ── Nonlinear dynamics ───────────────────────────────────────
|
|
||||||
# viscous_quadratic: velocity-squared drag term (N·m·s²/rad²).
|
|
||||||
# Captures nonlinear friction that increases at high speed (air drag,
|
|
||||||
# grease viscosity, etc.). Opposes motion.
|
|
||||||
ParamSpec("viscous_quadratic", 0.0, 0.0, 0.005),
|
|
||||||
# back_emf_gain: torque reduction proportional to |vel × ctrl|.
|
|
||||||
# Models the back-EMF effect: at high speed the motor produces less
|
|
||||||
# torque because the voltage drop across the armature is smaller.
|
|
||||||
ParamSpec("back_emf_gain", 0.0, 0.0, 0.05),
|
|
||||||
# stribeck_vel: characteristic velocity below which Coulomb friction
|
|
||||||
# is boosted (Stribeck effect). 0 = standard Coulomb only.
|
|
||||||
ParamSpec("stribeck_friction_boost", 0.0, 0.0, 0.15),
|
|
||||||
ParamSpec("stribeck_vel", 2.0, 0.1, 8.0),
|
|
||||||
# ── Rotor load ───────────────────────────────────────────────
|
|
||||||
ParamSpec("rotor_mass", 0.012, 0.002, 0.05, log_scale=True),
|
|
||||||
# ── Hardware realism ─────────────────────────────────────────
|
|
||||||
# deadzone_pos/neg: minimum |action| per direction.
|
|
||||||
ParamSpec("motor_deadzone_pos", 0.08, 0.0, 0.30),
|
|
||||||
ParamSpec("motor_deadzone_neg", 0.08, 0.0, 0.30),
|
|
||||||
# action_bias: constant offset added to ctrl (H-bridge asymmetry).
|
|
||||||
ParamSpec("action_bias", 0.0, -0.10, 0.10),
|
|
||||||
# ── Gearbox backlash ─────────────────────────────────────────
|
|
||||||
# backlash_halfwidth: half the angular deadband (rad) in the gearbox.
|
|
||||||
# When the motor reverses, the shaft doesn't move until the backlash
|
|
||||||
# gap is taken up. Typical for 30:1 plastic/metal spur gears.
|
|
||||||
ParamSpec("gearbox_backlash", 0.0, 0.0, 0.15),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def params_to_dict(
|
|
||||||
values: np.ndarray, specs: list[ParamSpec] | None = None
|
|
||||||
) -> dict[str, float]:
|
|
||||||
if specs is None:
|
|
||||||
specs = MOTOR_PARAMS
|
|
||||||
return {s.name: float(values[i]) for i, s in enumerate(specs)}
|
|
||||||
|
|
||||||
|
|
||||||
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
|
|
||||||
if specs is None:
|
|
||||||
specs = MOTOR_PARAMS
|
|
||||||
return np.array([s.default for s in specs], dtype=np.float64)
|
|
||||||
|
|
||||||
|
|
||||||
def bounds_arrays(
|
|
||||||
specs: list[ParamSpec] | None = None,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
if specs is None:
|
|
||||||
specs = MOTOR_PARAMS
|
|
||||||
lo = np.array([s.lower for s in specs], dtype=np.float64)
|
|
||||||
hi = np.array([s.upper for s in specs], dtype=np.float64)
|
|
||||||
return lo, hi
|
|
||||||
|
|
||||||
|
|
||||||
# ── MuJoCo model building with parameter overrides ──────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model(
|
|
||||||
asset_path: Path,
|
|
||||||
params: dict[str, float],
|
|
||||||
) -> mujoco.MjModel:
|
|
||||||
"""Build a MuJoCo model from motor.xml with parameter overrides.
|
|
||||||
|
|
||||||
Parses the MJCF, patches actuator/joint/body parameters, reloads.
|
|
||||||
"""
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
robot_cfg = yaml.safe_load((asset_path / "robot.yaml").read_text())
|
|
||||||
mjcf_path = asset_path / robot_cfg["mjcf"]
|
|
||||||
|
|
||||||
tree = ET.parse(str(mjcf_path))
|
|
||||||
root = tree.getroot()
|
|
||||||
|
|
||||||
# ── Actuator overrides ───────────────────────────────────────
|
|
||||||
# Use average of pos/neg gear for MuJoCo (asymmetry handled in ctrl).
|
|
||||||
gear_pos = params.get("actuator_gear_pos", params.get("actuator_gear", 0.064))
|
|
||||||
gear_neg = params.get("actuator_gear_neg", params.get("actuator_gear", 0.064))
|
|
||||||
gear = (gear_pos + gear_neg) / 2.0
|
|
||||||
filter_tau = params.get("actuator_filter_tau", 0.03)
|
|
||||||
|
|
||||||
for act_el in root.iter("general"):
|
|
||||||
if act_el.get("name") == "motor":
|
|
||||||
act_el.set("gear", str(gear))
|
|
||||||
if filter_tau > 0:
|
|
||||||
act_el.set("dyntype", "filter")
|
|
||||||
act_el.set("dynprm", str(filter_tau))
|
|
||||||
else:
|
|
||||||
act_el.set("dyntype", "none")
|
|
||||||
act_el.set("dynprm", "1")
|
|
||||||
|
|
||||||
# ── Joint overrides ──────────────────────────────────────────
|
|
||||||
# Damping and friction are asymmetric + nonlinear → applied manually.
|
|
||||||
# Set MuJoCo damping & frictionloss to 0; we handle them in qfrc_applied.
|
|
||||||
armature = params.get("motor_armature", 0.0001)
|
|
||||||
|
|
||||||
for jnt in root.iter("joint"):
|
|
||||||
if jnt.get("name") == "motor_joint":
|
|
||||||
jnt.set("damping", "0")
|
|
||||||
jnt.set("armature", str(armature))
|
|
||||||
jnt.set("frictionloss", "0")
|
|
||||||
|
|
||||||
# ── Rotor mass override ──────────────────────────────────────
|
|
||||||
rotor_mass = params.get("rotor_mass", 0.012)
|
|
||||||
for geom in root.iter("geom"):
|
|
||||||
if geom.get("name") == "rotor_disk":
|
|
||||||
geom.set("mass", str(rotor_mass))
|
|
||||||
|
|
||||||
# Write temp file and load.
|
|
||||||
tmp_path = asset_path / "_tmp_motor_sysid.xml"
|
|
||||||
try:
|
|
||||||
tree.write(str(tmp_path), xml_declaration=True, encoding="unicode")
|
|
||||||
model = mujoco.MjModel.from_xml_path(str(tmp_path))
|
|
||||||
finally:
|
|
||||||
tmp_path.unlink(missing_ok=True)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
# ── Action + asymmetry transforms ────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _transform_action(
|
|
||||||
action: float,
|
|
||||||
params: dict[str, float],
|
|
||||||
) -> float:
|
|
||||||
"""Apply bias, direction-dependent deadzone, and gear scaling.
|
|
||||||
|
|
||||||
The MuJoCo actuator has the *average* gear ratio. We rescale the
|
|
||||||
control signal so that ``ctrl × gear_avg ≈ action × gear_dir``.
|
|
||||||
"""
|
|
||||||
# Constant bias (H-bridge asymmetry).
|
|
||||||
action = action + params.get("action_bias", 0.0)
|
|
||||||
|
|
||||||
# Direction-dependent deadzone.
|
|
||||||
dz_pos = params.get("motor_deadzone_pos", params.get("motor_deadzone", 0.08))
|
|
||||||
dz_neg = params.get("motor_deadzone_neg", params.get("motor_deadzone", 0.08))
|
|
||||||
if action >= 0 and action < dz_pos:
|
|
||||||
return 0.0
|
|
||||||
if action < 0 and action > -dz_neg:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
# Direction-dependent gear scaling.
|
|
||||||
# MuJoCo model uses gear_avg; we rescale ctrl to get the right torque.
|
|
||||||
gear_pos = params.get("actuator_gear_pos", params.get("actuator_gear", 0.064))
|
|
||||||
gear_neg = params.get("actuator_gear_neg", params.get("actuator_gear", 0.064))
|
|
||||||
gear_avg = (gear_pos + gear_neg) / 2.0
|
|
||||||
if gear_avg < 1e-8:
|
|
||||||
return 0.0
|
|
||||||
gear_dir = gear_pos if action >= 0 else gear_neg
|
|
||||||
return action * (gear_dir / gear_avg)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_forces(
|
|
||||||
data: mujoco.MjData,
|
|
||||||
vel: float,
|
|
||||||
ctrl: float,
|
|
||||||
params: dict[str, float],
|
|
||||||
backlash_state: list[float] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Apply all manual forces: asymmetric friction, damping, and nonlinear terms.
|
|
||||||
|
|
||||||
Everything that MuJoCo can't represent with its symmetric joint model
|
|
||||||
is injected here via ``qfrc_applied``.
|
|
||||||
|
|
||||||
Forces applied (all oppose motion or reduce torque):
|
|
||||||
1. Asymmetric Coulomb friction (with Stribeck boost at low speed)
|
|
||||||
2. Asymmetric viscous damping
|
|
||||||
3. Quadratic velocity drag
|
|
||||||
4. Back-EMF torque reduction (proportional to |vel|)
|
|
||||||
|
|
||||||
Backlash:
|
|
||||||
If backlash_state is provided, it is a 1-element list [gap_pos].
|
|
||||||
gap_pos tracks the motor's position within the backlash deadband.
|
|
||||||
When inside the gap, no actuator torque is transmitted to the
|
|
||||||
output shaft — only friction forces act.
|
|
||||||
"""
|
|
||||||
torque = 0.0
|
|
||||||
|
|
||||||
# ── Gearbox backlash ──────────────────────────────────────────
|
|
||||||
# Model: the gear teeth have play of 2×halfwidth radians.
|
|
||||||
# We track where the motor is within that gap. When at the
|
|
||||||
# edge (contact), actuator torque passes through normally.
|
|
||||||
# When inside the gap, no actuator torque is transmitted.
|
|
||||||
backlash_hw = params.get("gearbox_backlash", 0.0)
|
|
||||||
actuator_torque_scale = 1.0 # 1.0 = full contact, 0.0 = in gap
|
|
||||||
|
|
||||||
if backlash_hw > 0 and backlash_state is not None:
|
|
||||||
# gap_pos: how far into the backlash gap we are.
|
|
||||||
# Range: [-backlash_hw, +backlash_hw]
|
|
||||||
# At ±backlash_hw, gears are in contact and torque transmits.
|
|
||||||
gap = backlash_state[0]
|
|
||||||
# Update gap position based on velocity.
|
|
||||||
dt_sub = data.model.opt.timestep
|
|
||||||
gap += vel * dt_sub
|
|
||||||
# Clamp to backlash range.
|
|
||||||
if gap > backlash_hw:
|
|
||||||
gap = backlash_hw
|
|
||||||
elif gap < -backlash_hw:
|
|
||||||
gap = -backlash_hw
|
|
||||||
|
|
||||||
backlash_state[0] = gap
|
|
||||||
|
|
||||||
# If not at contact edge, no torque transmitted.
|
|
||||||
if abs(gap) < backlash_hw - 1e-8:
|
|
||||||
actuator_torque_scale = 0.0
|
|
||||||
else:
|
|
||||||
actuator_torque_scale = 1.0
|
|
||||||
|
|
||||||
# ── 1. Coulomb friction (direction-dependent + Stribeck) ─────
|
|
||||||
fl_pos = params.get("motor_frictionloss_pos", params.get("motor_frictionloss", 0.03))
|
|
||||||
fl_neg = params.get("motor_frictionloss_neg", params.get("motor_frictionloss", 0.03))
|
|
||||||
stribeck_boost = params.get("stribeck_friction_boost", 0.0)
|
|
||||||
stribeck_vel = params.get("stribeck_vel", 2.0)
|
|
||||||
|
|
||||||
if abs(vel) > 1e-6:
|
|
||||||
fl = fl_pos if vel > 0 else fl_neg
|
|
||||||
# Stribeck: boost friction at low speed. Exponential decay.
|
|
||||||
if stribeck_boost > 0 and stribeck_vel > 0:
|
|
||||||
fl = fl * (1.0 + stribeck_boost * np.exp(-abs(vel) / stribeck_vel))
|
|
||||||
# Coulomb: constant magnitude, opposes motion.
|
|
||||||
torque -= np.sign(vel) * fl
|
|
||||||
|
|
||||||
# ── 2. Asymmetric viscous damping ────────────────────────────
|
|
||||||
damp_pos = params.get("motor_damping_pos", params.get("motor_damping", 0.003))
|
|
||||||
damp_neg = params.get("motor_damping_neg", params.get("motor_damping", 0.003))
|
|
||||||
damp = damp_pos if vel > 0 else damp_neg
|
|
||||||
torque -= damp * vel
|
|
||||||
|
|
||||||
# ── 3. Quadratic velocity drag ───────────────────────────────
|
|
||||||
visc_quad = params.get("viscous_quadratic", 0.0)
|
|
||||||
if visc_quad > 0:
|
|
||||||
torque -= visc_quad * vel * abs(vel)
|
|
||||||
|
|
||||||
# ── 4. Back-EMF torque reduction ─────────────────────────────
|
|
||||||
# At high speed, the motor's effective torque is reduced because
|
|
||||||
# back-EMF opposes the supply voltage. Modelled as a torque that
|
|
||||||
# opposes the control signal proportional to speed.
|
|
||||||
bemf = params.get("back_emf_gain", 0.0)
|
|
||||||
if bemf > 0 and abs(ctrl) > 1e-6:
|
|
||||||
# The reduction should oppose the actuator torque direction.
|
|
||||||
torque -= bemf * vel * np.sign(ctrl) * actuator_torque_scale
|
|
||||||
|
|
||||||
# ── 5. Scale actuator contribution by backlash state ─────────
|
|
||||||
# When in the backlash gap, MuJoCo's actuator force should not
|
|
||||||
# transmit. We cancel it by applying an opposing force.
|
|
||||||
if actuator_torque_scale < 1.0:
|
|
||||||
# The actuator_force from MuJoCo will be applied by mj_step.
|
|
||||||
# We need to counteract it. data.qfrc_actuator isn't set yet
|
|
||||||
# at this point (pre-step), so we zero the ctrl instead.
|
|
||||||
# This is handled in the rollout loop by zeroing ctrl.
|
|
||||||
pass
|
|
||||||
|
|
||||||
data.qfrc_applied[0] = max(-10.0, min(10.0, torque))
|
|
||||||
return actuator_torque_scale
|
|
||||||
|
|
||||||
|
|
||||||
# ── Simulation rollout ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def rollout(
|
|
||||||
asset_path: str | Path,
|
|
||||||
params: dict[str, float],
|
|
||||||
actions: np.ndarray,
|
|
||||||
sim_dt: float = 0.002,
|
|
||||||
substeps: int = 10,
|
|
||||||
) -> dict[str, np.ndarray]:
|
|
||||||
"""Open-loop replay of recorded actions.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
asset_path : motor asset directory
|
|
||||||
params : named parameter overrides
|
|
||||||
actions : (N,) normalised actions [-1, 1] from the recording
|
|
||||||
sim_dt : MuJoCo physics timestep
|
|
||||||
substeps : physics substeps per control step (ctrl_dt = sim_dt × substeps)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict with motor_angle (N,) and motor_vel (N,).
|
|
||||||
"""
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
model = _build_model(asset_path, params)
|
|
||||||
model.opt.timestep = sim_dt
|
|
||||||
data = mujoco.MjData(model)
|
|
||||||
mujoco.mj_resetData(model, data)
|
|
||||||
|
|
||||||
n = len(actions)
|
|
||||||
|
|
||||||
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
|
||||||
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
|
||||||
|
|
||||||
# Backlash state: [gap_position]. Starts at 0 (centered in gap).
|
|
||||||
backlash_state = [0.0]
|
|
||||||
|
|
||||||
for i in range(n):
|
|
||||||
ctrl = _transform_action(actions[i], params)
|
|
||||||
data.ctrl[0] = ctrl
|
|
||||||
|
|
||||||
for _ in range(substeps):
|
|
||||||
scale = _apply_forces(data, data.qvel[0], ctrl, params, backlash_state)
|
|
||||||
# If in backlash gap, zero ctrl so actuator torque doesn't transmit.
|
|
||||||
if scale < 1.0:
|
|
||||||
data.ctrl[0] = 0.0
|
|
||||||
else:
|
|
||||||
data.ctrl[0] = ctrl
|
|
||||||
mujoco.mj_step(model, data)
|
|
||||||
|
|
||||||
# Bail out on NaN/instability.
|
|
||||||
if not np.isfinite(data.qpos[0]) or abs(data.qvel[0]) > 1e4:
|
|
||||||
sim_motor_angle[i:] = np.nan
|
|
||||||
sim_motor_vel[i:] = np.nan
|
|
||||||
break
|
|
||||||
|
|
||||||
sim_motor_angle[i] = data.qpos[0]
|
|
||||||
sim_motor_vel[i] = data.qvel[0]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"motor_angle": sim_motor_angle,
|
|
||||||
"motor_vel": sim_motor_vel,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def windowed_rollout(
|
|
||||||
asset_path: str | Path,
|
|
||||||
params: dict[str, float],
|
|
||||||
recording: dict[str, np.ndarray],
|
|
||||||
window_duration: float = 0.5,
|
|
||||||
sim_dt: float = 0.002,
|
|
||||||
substeps: int = 10,
|
|
||||||
) -> dict[str, np.ndarray]:
|
|
||||||
"""Multiple-shooting rollout for motor-only sysid.
|
|
||||||
|
|
||||||
Splits the recording into short windows. Each window is initialised
|
|
||||||
from the real motor state, preventing error accumulation.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
dict with motor_angle (N,), motor_vel (N,), n_windows (int).
|
|
||||||
"""
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
model = _build_model(asset_path, params)
|
|
||||||
model.opt.timestep = sim_dt
|
|
||||||
data = mujoco.MjData(model)
|
|
||||||
|
|
||||||
times = recording["time"]
|
|
||||||
actions = recording["action"]
|
|
||||||
real_angle = recording["motor_angle"]
|
|
||||||
real_vel = recording["motor_vel"]
|
|
||||||
n = len(actions)
|
|
||||||
|
|
||||||
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
|
||||||
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
|
||||||
|
|
||||||
# Compute window boundaries.
|
|
||||||
t0 = times[0]
|
|
||||||
t_end = times[-1]
|
|
||||||
window_starts: list[int] = []
|
|
||||||
current_t = t0
|
|
||||||
while current_t < t_end:
|
|
||||||
idx = int(np.searchsorted(times, current_t))
|
|
||||||
idx = min(idx, n - 1)
|
|
||||||
window_starts.append(idx)
|
|
||||||
current_t += window_duration
|
|
||||||
|
|
||||||
n_windows = len(window_starts)
|
|
||||||
|
|
||||||
for w, w_start in enumerate(window_starts):
|
|
||||||
w_end = window_starts[w + 1] if w + 1 < n_windows else n
|
|
||||||
|
|
||||||
# Init from real state at window start.
|
|
||||||
mujoco.mj_resetData(model, data)
|
|
||||||
data.qpos[0] = real_angle[w_start]
|
|
||||||
data.qvel[0] = real_vel[w_start]
|
|
||||||
data.ctrl[:] = 0.0
|
|
||||||
mujoco.mj_forward(model, data)
|
|
||||||
|
|
||||||
# Backlash state resets each window (assume contact at start).
|
|
||||||
backlash_state = [0.0]
|
|
||||||
|
|
||||||
for i in range(w_start, w_end):
|
|
||||||
ctrl = _transform_action(actions[i], params)
|
|
||||||
data.ctrl[0] = ctrl
|
|
||||||
|
|
||||||
for _ in range(substeps):
|
|
||||||
scale = _apply_forces(data, data.qvel[0], ctrl, params, backlash_state)
|
|
||||||
if scale < 1.0:
|
|
||||||
data.ctrl[0] = 0.0
|
|
||||||
else:
|
|
||||||
data.ctrl[0] = ctrl
|
|
||||||
mujoco.mj_step(model, data)
|
|
||||||
|
|
||||||
# Bail out on NaN/instability — fill rest of window with last good.
|
|
||||||
if not np.isfinite(data.qpos[0]) or abs(data.qvel[0]) > 1e4:
|
|
||||||
sim_motor_angle[i:w_end] = sim_motor_angle[max(i-1, w_start)]
|
|
||||||
sim_motor_vel[i:w_end] = 0.0
|
|
||||||
break
|
|
||||||
|
|
||||||
sim_motor_angle[i] = data.qpos[0]
|
|
||||||
sim_motor_vel[i] = data.qvel[0]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"motor_angle": sim_motor_angle,
|
|
||||||
"motor_vel": sim_motor_vel,
|
|
||||||
"n_windows": n_windows,
|
|
||||||
}
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
"""Visualise motor system identification — real vs simulated trajectories.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m src.sysid.motor.visualize \
|
|
||||||
--recording assets/motor/recordings/motor_capture_YYYYMMDD_HHMMSS.npz
|
|
||||||
|
|
||||||
# With tuned result:
|
|
||||||
python -m src.sysid.motor.visualize \
|
|
||||||
--recording <file>.npz \
|
|
||||||
--result assets/motor/motor_sysid_result.json
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
log = structlog.get_logger()
|
|
||||||
|
|
||||||
_DEFAULT_ASSET = "assets/motor"
|
|
||||||
|
|
||||||
|
|
||||||
def visualize(
|
|
||||||
asset_path: str | Path = _DEFAULT_ASSET,
|
|
||||||
recording_path: str | Path = "",
|
|
||||||
result_path: str | Path | None = None,
|
|
||||||
sim_dt: float = 0.002,
|
|
||||||
substeps: int = 10,
|
|
||||||
window_duration: float = 0.5,
|
|
||||||
save_path: str | Path | None = None,
|
|
||||||
show: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""Generate 3-panel comparison plot: angle, velocity, action."""
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from src.sysid.motor.rollout import (
|
|
||||||
MOTOR_PARAMS,
|
|
||||||
defaults_vector,
|
|
||||||
params_to_dict,
|
|
||||||
rollout,
|
|
||||||
windowed_rollout,
|
|
||||||
)
|
|
||||||
|
|
||||||
asset_path = Path(asset_path).resolve()
|
|
||||||
recording = dict(np.load(recording_path))
|
|
||||||
|
|
||||||
t = recording["time"]
|
|
||||||
actions = recording["action"]
|
|
||||||
|
|
||||||
# ── Simulate with default parameters ─────────────────────────
|
|
||||||
default_params = params_to_dict(defaults_vector(MOTOR_PARAMS), MOTOR_PARAMS)
|
|
||||||
log.info("simulating_default_params", windowed=window_duration > 0)
|
|
||||||
|
|
||||||
if window_duration > 0:
|
|
||||||
sim_default = windowed_rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=default_params,
|
|
||||||
recording=recording,
|
|
||||||
window_duration=window_duration,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sim_default = rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=default_params,
|
|
||||||
actions=actions,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Simulate with tuned parameters (if available) ────────────
|
|
||||||
sim_tuned = None
|
|
||||||
tuned_cost = None
|
|
||||||
|
|
||||||
if result_path is not None:
|
|
||||||
result_path = Path(result_path)
|
|
||||||
else:
|
|
||||||
# Auto-detect.
|
|
||||||
auto = asset_path / "motor_sysid_result.json"
|
|
||||||
if auto.exists():
|
|
||||||
result_path = auto
|
|
||||||
|
|
||||||
if result_path is not None and result_path.exists():
|
|
||||||
result = json.loads(result_path.read_text())
|
|
||||||
tuned_params = result.get("best_params", {})
|
|
||||||
tuned_cost = result.get("best_cost")
|
|
||||||
log.info("simulating_tuned_params", cost=tuned_cost)
|
|
||||||
|
|
||||||
if window_duration > 0:
|
|
||||||
sim_tuned = windowed_rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=tuned_params,
|
|
||||||
recording=recording,
|
|
||||||
window_duration=window_duration,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sim_tuned = rollout(
|
|
||||||
asset_path=asset_path,
|
|
||||||
params=tuned_params,
|
|
||||||
actions=actions,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Plot ─────────────────────────────────────────────────────
|
|
||||||
fig, axes = plt.subplots(3, 1, figsize=(14, 8), sharex=True)
|
|
||||||
|
|
||||||
# Motor angle.
|
|
||||||
ax = axes[0]
|
|
||||||
ax.plot(t, np.degrees(recording["motor_angle"]), "k-", lw=1.2, alpha=0.8, label="Real")
|
|
||||||
ax.plot(t, np.degrees(sim_default["motor_angle"]), "--", color="#d62728", lw=1.0, alpha=0.7, label="Sim (default)")
|
|
||||||
if sim_tuned is not None:
|
|
||||||
ax.plot(t, np.degrees(sim_tuned["motor_angle"]), "--", color="#2ca02c", lw=1.0, alpha=0.7, label="Sim (tuned)")
|
|
||||||
ax.set_ylabel("Motor Angle (°)")
|
|
||||||
ax.legend(loc="upper right", fontsize=8)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# Motor velocity.
|
|
||||||
ax = axes[1]
|
|
||||||
ax.plot(t, recording["motor_vel"], "k-", lw=1.2, alpha=0.8, label="Real")
|
|
||||||
ax.plot(t, sim_default["motor_vel"], "--", color="#d62728", lw=1.0, alpha=0.7, label="Sim (default)")
|
|
||||||
if sim_tuned is not None:
|
|
||||||
ax.plot(t, sim_tuned["motor_vel"], "--", color="#2ca02c", lw=1.0, alpha=0.7, label="Sim (tuned)")
|
|
||||||
ax.set_ylabel("Motor Velocity (rad/s)")
|
|
||||||
ax.legend(loc="upper right", fontsize=8)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# Action.
|
|
||||||
ax = axes[2]
|
|
||||||
ax.plot(t, actions, "b-", lw=0.8, alpha=0.6)
|
|
||||||
ax.set_ylabel("Action (norm)")
|
|
||||||
ax.set_xlabel("Time (s)")
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
ax.set_ylim(-1.1, 1.1)
|
|
||||||
|
|
||||||
# Title.
|
|
||||||
title = "Motor System Identification — Real vs Simulated"
|
|
||||||
if tuned_cost is not None:
|
|
||||||
from src.sysid.motor.optimize import cost_function
|
|
||||||
|
|
||||||
orig_cost = cost_function(
|
|
||||||
defaults_vector(MOTOR_PARAMS),
|
|
||||||
recording,
|
|
||||||
asset_path,
|
|
||||||
MOTOR_PARAMS,
|
|
||||||
sim_dt=sim_dt,
|
|
||||||
substeps=substeps,
|
|
||||||
window_duration=window_duration,
|
|
||||||
)
|
|
||||||
title += f"\nDefault cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
|
||||||
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
|
||||||
title += f" ({improvement:+.1f}%)"
|
|
||||||
|
|
||||||
fig.suptitle(title, fontsize=12)
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
if save_path:
|
|
||||||
fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
|
|
||||||
log.info("figure_saved", path=str(save_path))
|
|
||||||
|
|
||||||
if show:
|
|
||||||
plt.show()
|
|
||||||
else:
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Visualise motor system identification results."
|
|
||||||
)
|
|
||||||
parser.add_argument("--asset-path", type=str, default=_DEFAULT_ASSET)
|
|
||||||
parser.add_argument("--recording", type=str, required=True, help=".npz file")
|
|
||||||
parser.add_argument("--result", type=str, default=None, help="sysid result JSON")
|
|
||||||
parser.add_argument("--sim-dt", type=float, default=0.002)
|
|
||||||
parser.add_argument("--substeps", type=int, default=10)
|
|
||||||
parser.add_argument("--window-duration", type=float, default=0.5)
|
|
||||||
parser.add_argument("--save", type=str, default=None, help="Save figure path")
|
|
||||||
parser.add_argument("--no-show", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
visualize(
|
|
||||||
asset_path=args.asset_path,
|
|
||||||
recording_path=args.recording,
|
|
||||||
result_path=args.result,
|
|
||||||
sim_dt=args.sim_dt,
|
|
||||||
substeps=args.substeps,
|
|
||||||
window_duration=args.window_duration,
|
|
||||||
save_path=args.save,
|
|
||||||
show=not args.no_show,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -7,12 +7,13 @@ the simulated trajectory for comparison with the real recording.
|
|||||||
This module is the inner loop of the CMA-ES optimizer: it is called once
|
This module is the inner loop of the CMA-ES optimizer: it is called once
|
||||||
per candidate parameter vector per generation.
|
per candidate parameter vector per generation.
|
||||||
|
|
||||||
Motor parameters are **locked** from the motor-only sysid result.
|
Motor parameters are **locked** from the unified sysid result.
|
||||||
The optimizer only fits
|
The optimizer only fits
|
||||||
pendulum/arm inertial parameters, pendulum joint dynamics, and
|
pendulum/arm inertial parameters, pendulum joint dynamics, and
|
||||||
``ctrl_limit``. The asymmetric motor model (deadzone, gear compensation,
|
``ctrl_limit``. The asymmetric motor model (bias, deadzone, gear
|
||||||
Coulomb friction, viscous damping, quadratic drag, back-EMF) is applied
|
compensation, Coulomb + Stribeck friction, viscous damping) is applied
|
||||||
via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()``.
|
via ``ActuatorConfig.transform_ctrl()`` and ``compute_motor_force()`` —
|
||||||
|
the same code the training runners use, so sim == sysid by construction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -32,23 +33,25 @@ from src.runners.mujoco import ActuatorLimits, load_mujoco_model
|
|||||||
from src.sysid._urdf import patch_link_inertials
|
from src.sysid._urdf import patch_link_inertials
|
||||||
|
|
||||||
|
|
||||||
# ── Locked motor parameters (from motor-only sysid) ─────────────────
|
# ── Locked motor parameters (from the unified sysid) ────────────────
|
||||||
# These are FIXED and not optimised. They come from the 12-param model
|
# These are FIXED and not optimised. They come from the unified
|
||||||
# in robot.yaml (from motor-only sysid, cost 0.862).
|
# 28-param sysid run (assets/rotary_cartpole/sysid_result.json,
|
||||||
|
# cost 0.925) — Stribeck friction + action bias + ~96 ms motor lag.
|
||||||
|
|
||||||
LOCKED_MOTOR_PARAMS: dict[str, float] = {
|
LOCKED_MOTOR_PARAMS: dict[str, float] = {
|
||||||
"actuator_gear_pos": 0.424182,
|
"actuator_gear_pos": 0.846499,
|
||||||
"actuator_gear_neg": 0.425031,
|
"actuator_gear_neg": 1.183733,
|
||||||
"actuator_filter_tau": 0.00503506,
|
"actuator_filter_tau": 0.096263,
|
||||||
"motor_damping_pos": 0.00202682,
|
"motor_damping_pos": 0.013165,
|
||||||
"motor_damping_neg": 0.0146651,
|
"motor_damping_neg": 0.015452,
|
||||||
"motor_armature": 0.00277342,
|
"motor_armature": 0.001676,
|
||||||
"motor_frictionloss_pos": 0.0573282,
|
"motor_frictionloss_pos": 0.014244,
|
||||||
"motor_frictionloss_neg": 0.0533549,
|
"motor_frictionloss_neg": 0.001005,
|
||||||
"viscous_quadratic": 0.000285329,
|
"stribeck_friction_boost": 0.068594,
|
||||||
"back_emf_gain": 0.00675809,
|
"stribeck_vel": 5.279594,
|
||||||
"motor_deadzone_pos": 0.141291,
|
"motor_deadzone_pos": 0.181097,
|
||||||
"motor_deadzone_neg": 0.0780148,
|
"motor_deadzone_neg": 0.202072,
|
||||||
|
"action_bias": 0.056566,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -190,26 +193,34 @@ def _build_model(
|
|||||||
act_cfg = robot_yaml["actuators"][0]
|
act_cfg = robot_yaml["actuators"][0]
|
||||||
ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])
|
ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])
|
||||||
|
|
||||||
|
# The fitted ctrl_limit overrides the YAML ctrl_range so the rollout
|
||||||
|
# saturates at exactly the identified PWM bound.
|
||||||
|
if "ctrl_limit" in params:
|
||||||
|
ctrl_lo, ctrl_hi = -params["ctrl_limit"], params["ctrl_limit"]
|
||||||
|
|
||||||
actuator = ActuatorConfig(
|
actuator = ActuatorConfig(
|
||||||
joint=act_cfg["joint"],
|
joint=act_cfg["joint"],
|
||||||
type="motor",
|
type="motor",
|
||||||
gear=(gear_pos, gear_neg),
|
gear=(gear_pos, gear_neg),
|
||||||
ctrl_range=(ctrl_lo, ctrl_hi),
|
ctrl_range=(ctrl_lo, ctrl_hi),
|
||||||
deadzone=(
|
deadzone=(
|
||||||
motor_params.get("motor_deadzone_pos", 0.141),
|
motor_params.get("motor_deadzone_pos", 0.181),
|
||||||
motor_params.get("motor_deadzone_neg", 0.078),
|
motor_params.get("motor_deadzone_neg", 0.202),
|
||||||
),
|
),
|
||||||
damping=(
|
damping=(
|
||||||
motor_params.get("motor_damping_pos", 0.002),
|
motor_params.get("motor_damping_pos", 0.013),
|
||||||
motor_params.get("motor_damping_neg", 0.015),
|
motor_params.get("motor_damping_neg", 0.015),
|
||||||
),
|
),
|
||||||
frictionloss=(
|
frictionloss=(
|
||||||
motor_params.get("motor_frictionloss_pos", 0.057),
|
motor_params.get("motor_frictionloss_pos", 0.014),
|
||||||
motor_params.get("motor_frictionloss_neg", 0.053),
|
motor_params.get("motor_frictionloss_neg", 0.001),
|
||||||
),
|
),
|
||||||
filter_tau=motor_params.get("actuator_filter_tau", 0.005),
|
filter_tau=motor_params.get("actuator_filter_tau", 0.096),
|
||||||
viscous_quadratic=motor_params.get("viscous_quadratic", 0.000285),
|
viscous_quadratic=motor_params.get("viscous_quadratic", 0.0),
|
||||||
back_emf_gain=motor_params.get("back_emf_gain", 0.00676),
|
back_emf_gain=motor_params.get("back_emf_gain", 0.0),
|
||||||
|
stribeck_friction_boost=motor_params.get("stribeck_friction_boost", 0.0),
|
||||||
|
stribeck_vel=motor_params.get("stribeck_vel", 2.0),
|
||||||
|
action_bias=motor_params.get("action_bias", 0.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
robot = RobotConfig(
|
robot = RobotConfig(
|
||||||
@@ -276,7 +287,6 @@ def rollout(
|
|||||||
mujoco.mj_resetData(model, data)
|
mujoco.mj_resetData(model, data)
|
||||||
|
|
||||||
n = len(actions)
|
n = len(actions)
|
||||||
ctrl_limit = params.get("ctrl_limit", 0.588)
|
|
||||||
|
|
||||||
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
sim_motor_angle = np.zeros(n, dtype=np.float64)
|
||||||
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
sim_motor_vel = np.zeros(n, dtype=np.float64)
|
||||||
@@ -286,8 +296,8 @@ def rollout(
|
|||||||
limits = ActuatorLimits(model)
|
limits = ActuatorLimits(model)
|
||||||
|
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
action = max(-ctrl_limit, min(ctrl_limit, float(actions[i])))
|
# transform_ctrl clips to the (fitted) ctrl_range internally.
|
||||||
ctrl = actuator.transform_ctrl(action)
|
ctrl = actuator.transform_ctrl(float(actions[i]))
|
||||||
data.ctrl[0] = ctrl
|
data.ctrl[0] = ctrl
|
||||||
|
|
||||||
for _ in range(substeps):
|
for _ in range(substeps):
|
||||||
@@ -378,7 +388,6 @@ def windowed_rollout(
|
|||||||
window_starts.append(idx)
|
window_starts.append(idx)
|
||||||
current_t += window_duration
|
current_t += window_duration
|
||||||
|
|
||||||
ctrl_limit = params.get("ctrl_limit", 0.588)
|
|
||||||
n_windows = len(window_starts)
|
n_windows = len(window_starts)
|
||||||
|
|
||||||
for w, w_start in enumerate(window_starts):
|
for w, w_start in enumerate(window_starts):
|
||||||
@@ -393,8 +402,8 @@ def windowed_rollout(
|
|||||||
mujoco.mj_forward(model, data)
|
mujoco.mj_forward(model, data)
|
||||||
|
|
||||||
for i in range(w_start, w_end):
|
for i in range(w_start, w_end):
|
||||||
action = max(-ctrl_limit, min(ctrl_limit, float(actions[i])))
|
# transform_ctrl clips to the (fitted) ctrl_range internally.
|
||||||
ctrl = actuator.transform_ctrl(action)
|
ctrl = actuator.transform_ctrl(float(actions[i]))
|
||||||
data.ctrl[0] = ctrl
|
data.ctrl[0] = ctrl
|
||||||
|
|
||||||
for _ in range(substeps):
|
for _ in range(substeps):
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ def _run_sim(
|
|||||||
window_duration: float,
|
window_duration: float,
|
||||||
sim_dt: float,
|
sim_dt: float,
|
||||||
substeps: int,
|
substeps: int,
|
||||||
motor_params: dict[str, float],
|
|
||||||
) -> dict[str, np.ndarray]:
|
) -> dict[str, np.ndarray]:
|
||||||
"""Run windowed or open-loop rollout depending on window_duration."""
|
"""Run windowed or open-loop rollout depending on window_duration."""
|
||||||
from src.sysid.rollout import rollout, windowed_rollout
|
from src.sysid.rollout import rollout, windowed_rollout
|
||||||
@@ -44,11 +43,11 @@ def _run_sim(
|
|||||||
return windowed_rollout(
|
return windowed_rollout(
|
||||||
robot_path=robot_path, params=params, recording=recording,
|
robot_path=robot_path, params=params, recording=recording,
|
||||||
window_duration=window_duration, sim_dt=sim_dt,
|
window_duration=window_duration, sim_dt=sim_dt,
|
||||||
substeps=substeps, motor_params=motor_params,
|
substeps=substeps,
|
||||||
)
|
)
|
||||||
return rollout(
|
return rollout(
|
||||||
robot_path=robot_path, params=params, actions=recording["action"],
|
robot_path=robot_path, params=params, actions=recording["action"],
|
||||||
substeps=substeps, motor_params=motor_params,
|
substeps=substeps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,7 +65,6 @@ def visualize(
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from src.sysid.rollout import (
|
from src.sysid.rollout import (
|
||||||
LOCKED_MOTOR_PARAMS,
|
|
||||||
ROTARY_CARTPOLE_PARAMS,
|
ROTARY_CARTPOLE_PARAMS,
|
||||||
defaults_vector,
|
defaults_vector,
|
||||||
params_to_dict,
|
params_to_dict,
|
||||||
@@ -75,12 +73,10 @@ def visualize(
|
|||||||
robot_path = Path(robot_path).resolve()
|
robot_path = Path(robot_path).resolve()
|
||||||
recording = dict(np.load(recording_path))
|
recording = dict(np.load(recording_path))
|
||||||
|
|
||||||
motor_params = LOCKED_MOTOR_PARAMS
|
|
||||||
|
|
||||||
sim_kwargs = dict(
|
sim_kwargs = dict(
|
||||||
robot_path=robot_path, recording=recording,
|
robot_path=robot_path, recording=recording,
|
||||||
window_duration=window_duration, sim_dt=sim_dt,
|
window_duration=window_duration, sim_dt=sim_dt,
|
||||||
substeps=substeps, motor_params=motor_params,
|
substeps=substeps,
|
||||||
)
|
)
|
||||||
|
|
||||||
t = recording["time"]
|
t = recording["time"]
|
||||||
@@ -172,7 +168,6 @@ def visualize(
|
|||||||
sim_dt=sim_dt,
|
sim_dt=sim_dt,
|
||||||
substeps=substeps,
|
substeps=substeps,
|
||||||
window_duration=window_duration,
|
window_duration=window_duration,
|
||||||
motor_params=motor_params,
|
|
||||||
)
|
)
|
||||||
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
|
||||||
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class TrainerConfig:
|
|||||||
clip_ratio: float = 0.2
|
clip_ratio: float = 0.2
|
||||||
value_loss_scale: float = 0.5
|
value_loss_scale: float = 0.5
|
||||||
entropy_loss_scale: float = 0.01
|
entropy_loss_scale: float = 0.01
|
||||||
|
kl_threshold: float = 0.01 # KL-adaptive LR target; 0 = fixed LR
|
||||||
|
|
||||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||||
|
|
||||||
@@ -48,6 +49,10 @@ class TrainerConfig:
|
|||||||
record_video_every: int = 10_000 # 0 = disabled
|
record_video_every: int = 10_000 # 0 = disabled
|
||||||
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||||
|
|
||||||
|
# History encoder (implicit adaptation). The window size comes from
|
||||||
|
# the runner (runner.history_length) — single source of truth.
|
||||||
|
embedding_dim: int = 32 # history encoder output dimension
|
||||||
|
|
||||||
|
|
||||||
# ── Video-recording trainer ──────────────────────────────────────────
|
# ── Video-recording trainer ──────────────────────────────────────────
|
||||||
|
|
||||||
@@ -99,13 +104,18 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
else:
|
else:
|
||||||
states = next_states
|
states = next_states
|
||||||
|
|
||||||
# Periodic video recording
|
# Periodic video recording. Recording steps the (shared) envs,
|
||||||
|
# so it returns a freshly reset observation — the training loop
|
||||||
|
# MUST continue from it, otherwise the recorded transitions no
|
||||||
|
# longer match the actual env state.
|
||||||
if (
|
if (
|
||||||
self._tcfg
|
self._tcfg
|
||||||
and self._tcfg.record_video_every > 0
|
and self._tcfg.record_video_every > 0
|
||||||
and (timestep + 1) % self._tcfg.record_video_every == 0
|
and (timestep + 1) % self._tcfg.record_video_every == 0
|
||||||
):
|
):
|
||||||
self._record_video(timestep + 1)
|
fresh_states = self._record_video(timestep + 1)
|
||||||
|
if fresh_states is not None:
|
||||||
|
states = fresh_states
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────
|
# ── helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -117,12 +127,22 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
# SerialRunner has dt but no substeps — dt *is* the control period.
|
# SerialRunner has dt but no substeps — dt *is* the control period.
|
||||||
return max(1, int(round(1.0 / (dt * substeps))))
|
return max(1, int(round(1.0 / (dt * substeps))))
|
||||||
|
|
||||||
def _record_video(self, timestep: int) -> None:
|
def _record_video(self, timestep: int) -> torch.Tensor | None:
|
||||||
|
"""Record an eval episode and upload it to ClearML.
|
||||||
|
|
||||||
|
Returns the freshly reset observation the training loop should
|
||||||
|
continue from (the recording steps the shared envs), or ``None``
|
||||||
|
if even the final reset failed.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
import imageio.v3 as iio
|
import imageio.v3 as iio
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return
|
iio = None
|
||||||
|
|
||||||
|
# Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
|
||||||
|
# failure crash training — log it and skip the video.
|
||||||
|
if iio is not None:
|
||||||
|
try:
|
||||||
fps = self._get_fps()
|
fps = self._get_fps()
|
||||||
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||||
frames: list[np.ndarray] = []
|
frames: list[np.ndarray] = []
|
||||||
@@ -150,8 +170,18 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
"Training Video", f"step_{timestep}",
|
"Training Video", f"step_{timestep}",
|
||||||
local_path=path, iteration=timestep,
|
local_path=path, iteration=timestep,
|
||||||
)
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("video_recording_failed", timestep=timestep, error=str(exc))
|
||||||
|
|
||||||
self.env.reset()
|
# Always leave the envs freshly reset and hand the new observation
|
||||||
|
# back to the training loop.
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
states, _ = self.env.reset()
|
||||||
|
return states
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("post_video_reset_failed", timestep=timestep, error=str(exc))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ── Main trainer ─────────────────────────────────────────────────────
|
# ── Main trainer ─────────────────────────────────────────────────────
|
||||||
@@ -173,6 +203,12 @@ class Trainer:
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Determine raw obs dim (without history augmentation) and the
|
||||||
|
# history window size — both come from the runner so the model
|
||||||
|
# always matches the observation layout it produces.
|
||||||
|
raw_obs_dim = self.runner.env.observation_space.shape[0]
|
||||||
|
history_length = getattr(self.runner.config, "history_length", 0)
|
||||||
|
|
||||||
self.model = SharedMLP(
|
self.model = SharedMLP(
|
||||||
observation_space=obs_space,
|
observation_space=obs_space,
|
||||||
action_space=act_space,
|
action_space=act_space,
|
||||||
@@ -181,6 +217,9 @@ class Trainer:
|
|||||||
initial_log_std=self.config.initial_log_std,
|
initial_log_std=self.config.initial_log_std,
|
||||||
min_log_std=self.config.min_log_std,
|
min_log_std=self.config.min_log_std,
|
||||||
max_log_std=self.config.max_log_std,
|
max_log_std=self.config.max_log_std,
|
||||||
|
history_length=history_length,
|
||||||
|
raw_obs_dim=raw_obs_dim,
|
||||||
|
embedding_dim=self.config.embedding_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
models = {"policy": self.model, "value": self.model}
|
models = {"policy": self.model, "value": self.model}
|
||||||
@@ -196,11 +235,20 @@ class Trainer:
|
|||||||
"ratio_clip": self.config.clip_ratio,
|
"ratio_clip": self.config.clip_ratio,
|
||||||
"value_loss_scale": self.config.value_loss_scale,
|
"value_loss_scale": self.config.value_loss_scale,
|
||||||
"entropy_loss_scale": self.config.entropy_loss_scale,
|
"entropy_loss_scale": self.config.entropy_loss_scale,
|
||||||
|
# Truncation (time limit) must bootstrap from the value function;
|
||||||
|
# without this the value target is biased at every max_steps cut.
|
||||||
|
"time_limit_bootstrap": True,
|
||||||
"state_preprocessor": RunningStandardScaler,
|
"state_preprocessor": RunningStandardScaler,
|
||||||
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
||||||
"value_preprocessor": RunningStandardScaler,
|
"value_preprocessor": RunningStandardScaler,
|
||||||
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
||||||
})
|
})
|
||||||
|
if self.config.kl_threshold > 0:
|
||||||
|
from skrl.resources.schedulers.torch import KLAdaptiveLR
|
||||||
|
agent_cfg["learning_rate_scheduler"] = KLAdaptiveLR
|
||||||
|
agent_cfg["learning_rate_scheduler_kwargs"] = {
|
||||||
|
"kl_threshold": self.config.kl_threshold,
|
||||||
|
}
|
||||||
# Wire up logging frequency: write_interval is in timesteps.
|
# Wire up logging frequency: write_interval is in timesteps.
|
||||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||||
|
|||||||
7
tests/conftest.py
Normal file
7
tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Make `src.*` importable regardless of pytest invocation directory.
|
||||||
|
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||||
|
if _PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
79
tests/test_reward.py
Normal file
79
tests/test_reward.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Reward design tests — balancing must strictly dominate spinning."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.envs.rotary_cartpole import (
|
||||||
|
RotaryCartPoleConfig,
|
||||||
|
RotaryCartPoleEnv,
|
||||||
|
RotaryCartPoleState,
|
||||||
|
)
|
||||||
|
|
||||||
|
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
|
||||||
|
|
||||||
|
|
||||||
|
def _env() -> RotaryCartPoleEnv:
|
||||||
|
return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||||
|
|
||||||
|
|
||||||
|
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
|
||||||
|
t = lambda v: torch.tensor([float(v)])
|
||||||
|
return RotaryCartPoleState(
|
||||||
|
motor_angle=t(motor), motor_vel=t(motor_vel),
|
||||||
|
pendulum_angle=t(pend), pendulum_vel=t(pend_vel),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
|
||||||
|
a = torch.tensor([[float(action)]])
|
||||||
|
pa = torch.tensor([[float(prev_action)]])
|
||||||
|
return float(env.compute_rewards(state, a, pa)[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_balancing_beats_spinning_through_upright():
|
||||||
|
env = _env()
|
||||||
|
balanced = _state(pend=math.pi, pend_vel=0.0)
|
||||||
|
spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top
|
||||||
|
assert _reward(env, balanced) > 2.0 * _reward(env, spinning)
|
||||||
|
|
||||||
|
|
||||||
|
def test_average_spin_cycle_reward_below_balance():
|
||||||
|
"""Mean reward over a full revolution at high speed << balanced reward."""
|
||||||
|
env = _env()
|
||||||
|
angles = torch.linspace(0, 2 * math.pi, 32)
|
||||||
|
spin_rewards = [
|
||||||
|
_reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
|
||||||
|
]
|
||||||
|
mean_spin = sum(spin_rewards) / len(spin_rewards)
|
||||||
|
balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
|
||||||
|
assert balanced > 3.0 * mean_spin
|
||||||
|
|
||||||
|
|
||||||
|
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
|
||||||
|
env = _env()
|
||||||
|
over_limit = _state(motor=math.radians(95.0), pend=math.pi)
|
||||||
|
assert _reward(env, over_limit) == -10.0
|
||||||
|
assert bool(env.compute_terminations(over_limit)[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_rate_penalty_reduces_reward():
|
||||||
|
env = _env()
|
||||||
|
s = _state(pend=math.pi)
|
||||||
|
smooth = _reward(env, s, action=0.5, prev_action=0.5)
|
||||||
|
jerky = _reward(env, s, action=0.5, prev_action=-0.5)
|
||||||
|
assert smooth > jerky
|
||||||
|
assert smooth - jerky == pytest.approx(
|
||||||
|
env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initial_state_ranges_widen_pendulum_only():
|
||||||
|
env = _env()
|
||||||
|
qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
|
||||||
|
assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
|
||||||
|
assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0)
|
||||||
|
assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0)
|
||||||
|
assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()
|
||||||
125
tests/test_robot_config.py
Normal file
125
tests/test_robot_config.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Robot config loading + motor model unit tests."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.robot import ActuatorConfig, load_robot_config
|
||||||
|
|
||||||
|
ROBOT_DIR = Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Loading ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_canonical_robot_yaml_loads_full_motor_model():
|
||||||
|
robot = load_robot_config(ROBOT_DIR)
|
||||||
|
act = robot.actuators[0]
|
||||||
|
|
||||||
|
assert act.has_motor_model
|
||||||
|
# Tuned (unified sysid) values must survive the round-trip.
|
||||||
|
assert act.filter_tau == pytest.approx(0.096263)
|
||||||
|
assert act.stribeck_friction_boost == pytest.approx(0.068594)
|
||||||
|
assert act.stribeck_vel == pytest.approx(5.279594)
|
||||||
|
assert act.action_bias == pytest.approx(0.056566)
|
||||||
|
assert act.gear == pytest.approx((0.846499, 1.183733))
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
|
||||||
|
(tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
|
||||||
|
(tmp_path / "robot.yaml").write_text(
|
||||||
|
"urdf: dummy.urdf\n"
|
||||||
|
"actuators:\n"
|
||||||
|
" - joint: j\n"
|
||||||
|
" type: motor\n"
|
||||||
|
" gear: [1.0, 1.0]\n"
|
||||||
|
" some_future_field: 42\n"
|
||||||
|
)
|
||||||
|
robot = load_robot_config(tmp_path) # must not raise
|
||||||
|
assert robot.actuators[0].joint == "j"
|
||||||
|
|
||||||
|
|
||||||
|
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def act() -> ActuatorConfig:
|
||||||
|
return ActuatorConfig(
|
||||||
|
joint="m",
|
||||||
|
gear=(0.8, 1.2),
|
||||||
|
ctrl_range=(-0.6, 0.6),
|
||||||
|
deadzone=(0.15, 0.20),
|
||||||
|
frictionloss=(0.014, 0.001),
|
||||||
|
damping=(0.013, 0.015),
|
||||||
|
stribeck_friction_boost=0.07,
|
||||||
|
stribeck_vel=5.0,
|
||||||
|
action_bias=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_ctrl_clips_to_ctrl_range(act):
|
||||||
|
# 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
|
||||||
|
out = act.transform_ctrl(1.0)
|
||||||
|
assert out == pytest.approx((0.6 + 0.05) * 0.8 / 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
|
||||||
|
# 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
|
||||||
|
assert act.transform_ctrl(0.05) == 0.0
|
||||||
|
# -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
|
||||||
|
assert act.transform_ctrl(-0.15) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
|
||||||
|
pos = act.transform_ctrl(0.5) # (0.55) * 0.8
|
||||||
|
neg = act.transform_ctrl(-0.5) # (-0.45) * 1.2
|
||||||
|
assert pos == pytest.approx(0.55 * 0.8)
|
||||||
|
assert neg == pytest.approx(-0.45 * 1.2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_action_matches_transform_ctrl_elementwise(act):
|
||||||
|
vals = torch.linspace(-1.2, 1.2, 49)
|
||||||
|
batched = act.transform_action(vals.clone())
|
||||||
|
scalar = torch.tensor([act.transform_ctrl(float(v)) for v in vals])
|
||||||
|
assert torch.allclose(batched, scalar, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_friction_opposes_motion(act):
|
||||||
|
assert act.compute_motor_force(vel=2.0, ctrl=0.0) < 0
|
||||||
|
assert act.compute_motor_force(vel=-2.0, ctrl=0.0) > 0
|
||||||
|
assert act.compute_motor_force(vel=0.0, ctrl=0.0) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stribeck_boost_decays_with_speed(act):
|
||||||
|
"""Friction torque magnitude (minus damping) is higher near standstill."""
|
||||||
|
no_strb = ActuatorConfig(
|
||||||
|
joint="m", gear=act.gear, frictionloss=act.frictionloss,
|
||||||
|
damping=(0.0, 0.0),
|
||||||
|
)
|
||||||
|
with_strb = ActuatorConfig(
|
||||||
|
joint="m", gear=act.gear, frictionloss=act.frictionloss,
|
||||||
|
damping=(0.0, 0.0),
|
||||||
|
stribeck_friction_boost=0.07, stribeck_vel=5.0,
|
||||||
|
)
|
||||||
|
v_slow, v_fast = 0.1, 50.0
|
||||||
|
extra_slow = abs(with_strb.compute_motor_force(v_slow, 0.0)) - abs(
|
||||||
|
no_strb.compute_motor_force(v_slow, 0.0))
|
||||||
|
extra_fast = abs(with_strb.compute_motor_force(v_fast, 0.0)) - abs(
|
||||||
|
no_strb.compute_motor_force(v_fast, 0.0))
|
||||||
|
|
||||||
|
assert extra_slow == pytest.approx(
|
||||||
|
0.07 * math.exp(-((v_slow / 5.0) ** 2)), abs=1e-9)
|
||||||
|
assert extra_fast < extra_slow
|
||||||
|
assert extra_fast == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
|
||||||
|
def test_friction_scale_dr_multiplies_friction(act):
|
||||||
|
base = act.compute_motor_force(1.0, 0.0, friction_scale=1.0, damping_scale=0.0)
|
||||||
|
# damping_scale=0 isolates the friction term
|
||||||
|
doubled = act.compute_motor_force(1.0, 0.0, friction_scale=2.0, damping_scale=0.0)
|
||||||
|
assert doubled == pytest.approx(2.0 * base)
|
||||||
173
tests/test_runner.py
Normal file
173
tests/test_runner.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""Runner integration tests — DR, history, action delay, init randomization.
|
||||||
|
|
||||||
|
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
|
||||||
|
is skipped when JAX isn't installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.registry import build_env
|
||||||
|
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||||
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||||
|
|
||||||
|
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
|
||||||
|
|
||||||
|
DR = {
|
||||||
|
"qpos_noise_std": 0.01,
|
||||||
|
"qvel_noise_std": 0.5,
|
||||||
|
"action_delay_steps": [0, 2],
|
||||||
|
"friction_scale": [0.6, 1.6],
|
||||||
|
"damping_scale": [0.6, 1.6],
|
||||||
|
"torque_scale": [0.85, 1.15],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
|
||||||
|
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||||
|
cfg = MuJoCoRunnerConfig(
|
||||||
|
num_envs=num_envs,
|
||||||
|
device="cpu",
|
||||||
|
history_length=history_length,
|
||||||
|
domain_rand=domain_rand or {},
|
||||||
|
)
|
||||||
|
return MuJoCoRunner(env=env, config=cfg)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Observation layout ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_obs_is_raw_plus_history():
|
||||||
|
runner = _runner(num_envs=2, history_length=3)
|
||||||
|
raw_dim = runner.env.observation_space.shape[0] # 6
|
||||||
|
step_dim = raw_dim + 1 # + action
|
||||||
|
assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim
|
||||||
|
|
||||||
|
obs, _ = runner.reset()
|
||||||
|
assert obs.shape == (2, raw_dim + 3 * step_dim)
|
||||||
|
# Fresh history must be zero.
|
||||||
|
assert torch.all(obs[:, raw_dim:] == 0)
|
||||||
|
|
||||||
|
actions = torch.full((2, 1), 0.3)
|
||||||
|
obs, rewards, term, trunc, info = runner.step(actions)
|
||||||
|
assert obs.shape == (2, raw_dim + 3 * step_dim)
|
||||||
|
assert rewards.shape == (2, 1)
|
||||||
|
# Newest history slot holds the commanded action.
|
||||||
|
assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_history_keeps_plain_obs():
|
||||||
|
runner = _runner(num_envs=2, history_length=0)
|
||||||
|
assert runner.observation_space.shape[0] == 6
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Domain randomization ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_dr_scales_sampled_within_ranges_and_resampled():
|
||||||
|
runner = _runner(num_envs=16, domain_rand=DR)
|
||||||
|
runner.reset()
|
||||||
|
for name, (lo, hi) in (
|
||||||
|
("friction_scale", (0.6, 1.6)),
|
||||||
|
("damping_scale", (0.6, 1.6)),
|
||||||
|
("torque_scale", (0.85, 1.15)),
|
||||||
|
):
|
||||||
|
vals = runner._dr_scales[name]
|
||||||
|
assert torch.all(vals >= lo) and torch.all(vals <= hi)
|
||||||
|
# 16 independent uniform samples are never all identical.
|
||||||
|
assert vals.std() > 0
|
||||||
|
|
||||||
|
before = runner._dr_scales["friction_scale"].clone()
|
||||||
|
runner.reset()
|
||||||
|
assert not torch.equal(before, runner._dr_scales["friction_scale"])
|
||||||
|
|
||||||
|
delays = runner._dr_delay
|
||||||
|
assert torch.all(delays >= 0) and torch.all(delays <= 2)
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dr_disabled_is_noop():
|
||||||
|
runner = _runner(num_envs=2, domain_rand={})
|
||||||
|
runner.reset()
|
||||||
|
for vals in runner._dr_scales.values():
|
||||||
|
assert torch.all(vals == 1.0)
|
||||||
|
assert runner._max_delay == 0
|
||||||
|
assert runner._qpos_noise_std == 0.0
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_delay_buffer_returns_lagged_action():
|
||||||
|
runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
|
||||||
|
runner.reset()
|
||||||
|
runner._dr_delay = torch.tensor([0, 1, 2])
|
||||||
|
runner._action_buf.zero_()
|
||||||
|
|
||||||
|
a1 = torch.tensor([[1.0], [1.0], [1.0]])
|
||||||
|
a2 = torch.tensor([[2.0], [2.0], [2.0]])
|
||||||
|
a3 = torch.tensor([[3.0], [3.0], [3.0]])
|
||||||
|
|
||||||
|
d1 = runner._apply_action_delay(a1)
|
||||||
|
d2 = runner._apply_action_delay(a2)
|
||||||
|
d3 = runner._apply_action_delay(a3)
|
||||||
|
|
||||||
|
assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
|
||||||
|
assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
|
||||||
|
assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Initial-state randomization ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_wide_pendulum_init_actually_applied():
|
||||||
|
runner = _runner(num_envs=32)
|
||||||
|
qpos, _ = runner._sim_reset(torch.arange(32))
|
||||||
|
pend_angles = qpos[:, 1]
|
||||||
|
# With ±180° init range, samples must spread far beyond the old ±0.05.
|
||||||
|
assert pend_angles.abs().max() > 1.0
|
||||||
|
assert pend_angles.std() > 0.5
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sim_reset_returns_full_batch():
|
||||||
|
runner = _runner(num_envs=4)
|
||||||
|
runner.reset()
|
||||||
|
qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only
|
||||||
|
assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
|
||||||
|
runner.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_mjx_runner_smoke():
|
||||||
|
pytest.importorskip("jax")
|
||||||
|
pytest.importorskip("mujoco.mjx")
|
||||||
|
from src.runners.mjx import MJXRunner, MJXRunnerConfig
|
||||||
|
|
||||||
|
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
|
||||||
|
runner = MJXRunner(
|
||||||
|
env=env,
|
||||||
|
config=MJXRunnerConfig(
|
||||||
|
num_envs=4, device="cpu", history_length=3, domain_rand=DR,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
obs, _ = runner.reset()
|
||||||
|
raw_dim = env.observation_space.shape[0]
|
||||||
|
assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
actions = torch.rand(4, 1) * 2 - 1
|
||||||
|
obs, rewards, term, trunc, _ = runner.step(actions)
|
||||||
|
assert torch.isfinite(obs).all()
|
||||||
|
assert torch.isfinite(rewards).all()
|
||||||
|
|
||||||
|
# Wide pendulum init must reach MJX resets too.
|
||||||
|
qpos, qvel = runner._sim_reset(torch.arange(4))
|
||||||
|
assert qpos.shape[0] == 4 # full batch
|
||||||
|
runner.close()
|
||||||
Reference in New Issue
Block a user