♻️ full agent refactor

This commit is contained in:
2026-06-10 21:15:34 +02:00
parent a98e86ef66
commit 1e0836e1bc
49 changed files with 1309 additions and 829 deletions

7
tests/conftest.py Normal file
View File

@@ -0,0 +1,7 @@
import sys
from pathlib import Path
# Make `src.*` importable regardless of pytest invocation directory.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)

79
tests/test_reward.py Normal file
View File

@@ -0,0 +1,79 @@
"""Reward design tests — balancing must strictly dominate spinning."""
import math
from pathlib import Path
import pytest
import torch
from src.envs.rotary_cartpole import (
RotaryCartPoleConfig,
RotaryCartPoleEnv,
RotaryCartPoleState,
)
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
def _env() -> RotaryCartPoleEnv:
return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
t = lambda v: torch.tensor([float(v)])
return RotaryCartPoleState(
motor_angle=t(motor), motor_vel=t(motor_vel),
pendulum_angle=t(pend), pendulum_vel=t(pend_vel),
)
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
a = torch.tensor([[float(action)]])
pa = torch.tensor([[float(prev_action)]])
return float(env.compute_rewards(state, a, pa)[0])
def test_balancing_beats_spinning_through_upright():
env = _env()
balanced = _state(pend=math.pi, pend_vel=0.0)
spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top
assert _reward(env, balanced) > 2.0 * _reward(env, spinning)
def test_average_spin_cycle_reward_below_balance():
"""Mean reward over a full revolution at high speed << balanced reward."""
env = _env()
angles = torch.linspace(0, 2 * math.pi, 32)
spin_rewards = [
_reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
]
mean_spin = sum(spin_rewards) / len(spin_rewards)
balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
assert balanced > 3.0 * mean_spin
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
env = _env()
over_limit = _state(motor=math.radians(95.0), pend=math.pi)
assert _reward(env, over_limit) == -10.0
assert bool(env.compute_terminations(over_limit)[0])
def test_action_rate_penalty_reduces_reward():
env = _env()
s = _state(pend=math.pi)
smooth = _reward(env, s, action=0.5, prev_action=0.5)
jerky = _reward(env, s, action=0.5, prev_action=-0.5)
assert smooth > jerky
assert smooth - jerky == pytest.approx(
env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6,
)
def test_initial_state_ranges_widen_pendulum_only():
env = _env()
qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0)
assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()

125
tests/test_robot_config.py Normal file
View File

@@ -0,0 +1,125 @@
"""Robot config loading + motor model unit tests."""
import math
from pathlib import Path
import pytest
import torch
from src.core.robot import ActuatorConfig, load_robot_config
ROBOT_DIR = Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole"
# ── Loading ──────────────────────────────────────────────────────────
def test_canonical_robot_yaml_loads_full_motor_model():
robot = load_robot_config(ROBOT_DIR)
act = robot.actuators[0]
assert act.has_motor_model
# Tuned (unified sysid) values must survive the round-trip.
assert act.filter_tau == pytest.approx(0.096263)
assert act.stribeck_friction_boost == pytest.approx(0.068594)
assert act.stribeck_vel == pytest.approx(5.279594)
assert act.action_bias == pytest.approx(0.056566)
assert act.gear == pytest.approx((0.846499, 1.183733))
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
(tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
(tmp_path / "robot.yaml").write_text(
"urdf: dummy.urdf\n"
"actuators:\n"
" - joint: j\n"
" type: motor\n"
" gear: [1.0, 1.0]\n"
" some_future_field: 42\n"
)
robot = load_robot_config(tmp_path) # must not raise
assert robot.actuators[0].joint == "j"
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
@pytest.fixture
def act() -> ActuatorConfig:
return ActuatorConfig(
joint="m",
gear=(0.8, 1.2),
ctrl_range=(-0.6, 0.6),
deadzone=(0.15, 0.20),
frictionloss=(0.014, 0.001),
damping=(0.013, 0.015),
stribeck_friction_boost=0.07,
stribeck_vel=5.0,
action_bias=0.05,
)
def test_transform_ctrl_clips_to_ctrl_range(act):
# 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
out = act.transform_ctrl(1.0)
assert out == pytest.approx((0.6 + 0.05) * 0.8 / 1.0)
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
# 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
assert act.transform_ctrl(0.05) == 0.0
# -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
assert act.transform_ctrl(-0.15) == 0.0
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
pos = act.transform_ctrl(0.5) # (0.55) * 0.8
neg = act.transform_ctrl(-0.5) # (-0.45) * 1.2
assert pos == pytest.approx(0.55 * 0.8)
assert neg == pytest.approx(-0.45 * 1.2)
def test_transform_action_matches_transform_ctrl_elementwise(act):
vals = torch.linspace(-1.2, 1.2, 49)
batched = act.transform_action(vals.clone())
scalar = torch.tensor([act.transform_ctrl(float(v)) for v in vals])
assert torch.allclose(batched, scalar, atol=1e-6)
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
def test_friction_opposes_motion(act):
assert act.compute_motor_force(vel=2.0, ctrl=0.0) < 0
assert act.compute_motor_force(vel=-2.0, ctrl=0.0) > 0
assert act.compute_motor_force(vel=0.0, ctrl=0.0) == 0.0
def test_stribeck_boost_decays_with_speed(act):
"""Friction torque magnitude (minus damping) is higher near standstill."""
no_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
)
with_strb = ActuatorConfig(
joint="m", gear=act.gear, frictionloss=act.frictionloss,
damping=(0.0, 0.0),
stribeck_friction_boost=0.07, stribeck_vel=5.0,
)
v_slow, v_fast = 0.1, 50.0
extra_slow = abs(with_strb.compute_motor_force(v_slow, 0.0)) - abs(
no_strb.compute_motor_force(v_slow, 0.0))
extra_fast = abs(with_strb.compute_motor_force(v_fast, 0.0)) - abs(
no_strb.compute_motor_force(v_fast, 0.0))
assert extra_slow == pytest.approx(
0.07 * math.exp(-((v_slow / 5.0) ** 2)), abs=1e-9)
assert extra_fast < extra_slow
assert extra_fast == pytest.approx(0.0, abs=1e-9)
def test_friction_scale_dr_multiplies_friction(act):
base = act.compute_motor_force(1.0, 0.0, friction_scale=1.0, damping_scale=0.0)
# damping_scale=0 isolates the friction term
doubled = act.compute_motor_force(1.0, 0.0, friction_scale=2.0, damping_scale=0.0)
assert doubled == pytest.approx(2.0 * base)

173
tests/test_runner.py Normal file
View File

@@ -0,0 +1,173 @@
"""Runner integration tests — DR, history, action delay, init randomization.
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
is skipped when JAX isn't installed.
"""
import math
from pathlib import Path
import pytest
import torch
from src.core.registry import build_env
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
DR = {
"qpos_noise_std": 0.01,
"qvel_noise_std": 0.5,
"action_delay_steps": [0, 2],
"friction_scale": [0.6, 1.6],
"damping_scale": [0.6, 1.6],
"torque_scale": [0.85, 1.15],
}
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
cfg = MuJoCoRunnerConfig(
num_envs=num_envs,
device="cpu",
history_length=history_length,
domain_rand=domain_rand or {},
)
return MuJoCoRunner(env=env, config=cfg)
# ── Observation layout ───────────────────────────────────────────────
def test_obs_is_raw_plus_history():
runner = _runner(num_envs=2, history_length=3)
raw_dim = runner.env.observation_space.shape[0] # 6
step_dim = raw_dim + 1 # + action
assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim
obs, _ = runner.reset()
assert obs.shape == (2, raw_dim + 3 * step_dim)
# Fresh history must be zero.
assert torch.all(obs[:, raw_dim:] == 0)
actions = torch.full((2, 1), 0.3)
obs, rewards, term, trunc, info = runner.step(actions)
assert obs.shape == (2, raw_dim + 3 * step_dim)
assert rewards.shape == (2, 1)
# Newest history slot holds the commanded action.
assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
runner.close()
def test_no_history_keeps_plain_obs():
runner = _runner(num_envs=2, history_length=0)
assert runner.observation_space.shape[0] == 6
runner.close()
# ── Domain randomization ─────────────────────────────────────────────
def test_dr_scales_sampled_within_ranges_and_resampled():
runner = _runner(num_envs=16, domain_rand=DR)
runner.reset()
for name, (lo, hi) in (
("friction_scale", (0.6, 1.6)),
("damping_scale", (0.6, 1.6)),
("torque_scale", (0.85, 1.15)),
):
vals = runner._dr_scales[name]
assert torch.all(vals >= lo) and torch.all(vals <= hi)
# 16 independent uniform samples are never all identical.
assert vals.std() > 0
before = runner._dr_scales["friction_scale"].clone()
runner.reset()
assert not torch.equal(before, runner._dr_scales["friction_scale"])
delays = runner._dr_delay
assert torch.all(delays >= 0) and torch.all(delays <= 2)
runner.close()
def test_dr_disabled_is_noop():
runner = _runner(num_envs=2, domain_rand={})
runner.reset()
for vals in runner._dr_scales.values():
assert torch.all(vals == 1.0)
assert runner._max_delay == 0
assert runner._qpos_noise_std == 0.0
runner.close()
def test_action_delay_buffer_returns_lagged_action():
runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
runner.reset()
runner._dr_delay = torch.tensor([0, 1, 2])
runner._action_buf.zero_()
a1 = torch.tensor([[1.0], [1.0], [1.0]])
a2 = torch.tensor([[2.0], [2.0], [2.0]])
a3 = torch.tensor([[3.0], [3.0], [3.0]])
d1 = runner._apply_action_delay(a1)
d2 = runner._apply_action_delay(a2)
d3 = runner._apply_action_delay(a3)
assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
runner.close()
# ── Initial-state randomization ──────────────────────────────────────
def test_wide_pendulum_init_actually_applied():
runner = _runner(num_envs=32)
qpos, _ = runner._sim_reset(torch.arange(32))
pend_angles = qpos[:, 1]
# With ±180° init range, samples must spread far beyond the old ±0.05.
assert pend_angles.abs().max() > 1.0
assert pend_angles.std() > 0.5
runner.close()
def test_sim_reset_returns_full_batch():
runner = _runner(num_envs=4)
runner.reset()
qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only
assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
runner.close()
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
def test_mjx_runner_smoke():
pytest.importorskip("jax")
pytest.importorskip("mujoco.mjx")
from src.runners.mjx import MJXRunner, MJXRunnerConfig
env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
runner = MJXRunner(
env=env,
config=MJXRunnerConfig(
num_envs=4, device="cpu", history_length=3, domain_rand=DR,
),
)
obs, _ = runner.reset()
raw_dim = env.observation_space.shape[0]
assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))
for _ in range(3):
actions = torch.rand(4, 1) * 2 - 1
obs, rewards, term, trunc, _ = runner.step(actions)
assert torch.isfinite(obs).all()
assert torch.isfinite(rewards).all()
# Wide pendulum init must reach MJX resets too.
qpos, qvel = runner._sim_reset(torch.arange(4))
assert qpos.shape[0] == 4 # full batch
runner.close()

View File