♻️ full agent refactor
This commit is contained in:
7
tests/conftest.py
Normal file
7
tests/conftest.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Put the directory two levels above this file at the front of sys.path so
# that `src.*` imports resolve no matter where pytest is launched from.
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
    # Prepend (rather than append) so this checkout wins over other entries.
    sys.path[:0] = [_PROJECT_ROOT]
|
||||
79
tests/test_reward.py
Normal file
79
tests/test_reward.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Reward design tests — balancing must strictly dominate spinning."""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.envs.rotary_cartpole import (
|
||||
RotaryCartPoleConfig,
|
||||
RotaryCartPoleEnv,
|
||||
RotaryCartPoleState,
|
||||
)
|
||||
|
||||
# String path of the "assets/rotary_cartpole" directory, resolved relative to
# the directory two levels above this file.
ROBOT_PATH = str(
    Path(__file__).resolve().parent.parent / Path("assets", "rotary_cartpole")
)
|
||||
|
||||
|
||||
def _env() -> RotaryCartPoleEnv:
    """Create a RotaryCartPoleEnv whose config points at ROBOT_PATH."""
    config = RotaryCartPoleConfig(robot_path=ROBOT_PATH)
    return RotaryCartPoleEnv(config)
|
||||
|
||||
|
||||
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
    """Build a RotaryCartPoleState from four scalars.

    Each argument is converted with ``float`` and wrapped in a one-element
    tensor before being stored in the corresponding state field.
    """

    def as_tensor(value):
        return torch.tensor([float(value)])

    return RotaryCartPoleState(
        motor_angle=as_tensor(motor),
        motor_vel=as_tensor(motor_vel),
        pendulum_angle=as_tensor(pend),
        pendulum_vel=as_tensor(pend_vel),
    )
|
||||
|
||||
|
||||
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
    """Return ``env.compute_rewards`` for one state as a plain float.

    ``action`` and ``prev_action`` are scalars; each is handed to the env as a
    tensor of shape (1, 1). Only element 0 of the result is returned.
    """
    current, previous = (
        torch.tensor([[float(value)]]) for value in (action, prev_action)
    )
    batch_rewards = env.compute_rewards(state, current, previous)
    return float(batch_rewards[0])
|
||||
|
||||
|
||||
def test_balancing_beats_spinning_through_upright():
    """Resting at pend=π must earn more than twice the reward of passing through it at speed."""
    env = _env()
    at_rest = _state(pend=math.pi, pend_vel=0.0)
    # Same angle, but moving at full speed: a loop through the top.
    whipping = _state(pend=math.pi, pend_vel=10.0)
    reward_at_rest = _reward(env, at_rest)
    reward_whipping = _reward(env, whipping)
    assert reward_at_rest > 2.0 * reward_whipping
|
||||
|
||||
|
||||
def test_average_spin_cycle_reward_below_balance():
    """Mean reward over a full revolution at high speed << balanced reward."""
    env = _env()
    num_samples = 32
    # torch.linspace includes both endpoints, and 0 and 2π are the same
    # pendulum pose. Sample one extra point and drop the last so every pose in
    # the revolution is weighted exactly once in the mean.
    angles = torch.linspace(0, 2 * math.pi, num_samples + 1)[:-1]
    spin_rewards = [
        _reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
    ]
    mean_spin = sum(spin_rewards) / len(spin_rewards)
    balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
    assert balanced > 3.0 * mean_spin
|
||||
|
||||
|
||||
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
    """A motor angle of 95° must yield a reward of exactly -10 and a termination flag."""
    env = _env()
    beyond_limit = _state(motor=math.radians(95.0), pend=math.pi)
    penalty = _reward(env, beyond_limit)
    assert penalty == -10.0
    terminated = env.compute_terminations(beyond_limit)[0]
    assert bool(terminated)
|
||||
|
||||
|
||||
def test_action_rate_penalty_reduces_reward():
    """The reward gap between a steady and a flipped action must equal
    ``action_rate_penalty * Δa²`` (within 1e-6)."""
    env = _env()
    state = _state(pend=math.pi)
    steady = _reward(env, state, action=0.5, prev_action=0.5)
    abrupt = _reward(env, state, action=0.5, prev_action=-0.5)
    assert steady > abrupt
    action_delta = 0.5 - (-0.5)
    expected_gap = env.config.action_rate_penalty * action_delta ** 2
    assert steady - abrupt == pytest.approx(expected_gap, abs=1e-6)
|
||||
|
||||
|
||||
def test_initial_state_ranges_widen_pendulum_only():
    """Only qpos index 1 follows ``pendulum_init_range_deg``; the other bounds stay at ±0.05."""
    env = _env()
    qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
    # Index 0 keeps the narrow range.
    assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05
    # Index 1 spans the configured range, converted from degrees to radians.
    pend_limit = math.pi * (env.config.pendulum_init_range_deg / 180.0)
    assert qpos_lo[1] == -pend_limit
    assert qpos_hi[1] == pend_limit
    assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()
|
||||
125
tests/test_robot_config.py
Normal file
125
tests/test_robot_config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Robot config loading + motor model unit tests."""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.core.robot import ActuatorConfig, load_robot_config
|
||||
|
||||
# Path object for the "assets/rotary_cartpole" directory, resolved relative to
# the directory two levels above this file.
ROBOT_DIR = Path(__file__).resolve().parent.parent / Path("assets", "rotary_cartpole")
|
||||
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_canonical_robot_yaml_loads_full_motor_model():
    """The robot config in ROBOT_DIR must load with its tuned motor-model values intact."""
    first_actuator = load_robot_config(ROBOT_DIR).actuators[0]

    assert first_actuator.has_motor_model
    # Tuned (unified sysid) values must survive the round-trip.
    tuned_scalars = (
        ("filter_tau", 0.096263),
        ("stribeck_friction_boost", 0.068594),
        ("stribeck_vel", 5.279594),
        ("action_bias", 0.056566),
    )
    for field, expected in tuned_scalars:
        assert getattr(first_actuator, field) == pytest.approx(expected)
    assert first_actuator.gear == pytest.approx((0.846499, 1.183733))
|
||||
|
||||
|
||||
def test_unknown_actuator_keys_are_ignored_not_fatal(tmp_path):
    """An unrecognised key in an actuator entry must not make loading fail.

    Writes a minimal URDF and robot.yaml into ``tmp_path`` (pytest's
    per-test temporary directory) and loads the config from there.
    """
    (tmp_path / "dummy.urdf").write_text("<robot name='x'/>")
    # Every key of the actuator entry is aligned under `joint`, so the whole
    # entry parses as a single mapping inside the list item.
    (tmp_path / "robot.yaml").write_text(
        "urdf: dummy.urdf\n"
        "actuators:\n"
        "  - joint: j\n"
        "    type: motor\n"
        "    gear: [1.0, 1.0]\n"
        "    some_future_field: 42\n"
    )
    robot = load_robot_config(tmp_path)  # must not raise
    assert robot.actuators[0].joint == "j"
|
||||
|
||||
|
||||
# ── transform_ctrl: clip → bias → deadzone → gear ────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def act() -> ActuatorConfig:
    """Actuator config shared by the motor-model tests in this file.

    The tests that use it treat ``gear`` and ``deadzone`` as
    (positive-direction, negative-direction) pairs.
    """
    settings = dict(
        joint="m",
        gear=(0.8, 1.2),
        ctrl_range=(-0.6, 0.6),
        deadzone=(0.15, 0.20),
        frictionloss=(0.014, 0.001),
        damping=(0.013, 0.015),
        stribeck_friction_boost=0.07,
        stribeck_vel=5.0,
        action_bias=0.05,
    )
    return ActuatorConfig(**settings)
|
||||
|
||||
|
||||
def test_transform_ctrl_clips_to_ctrl_range(act):
    """A command above ctrl_range is clipped before bias and gear are applied."""
    # 1.0 clips to 0.6, then +bias=0.65, gear comp 0.8/1.0 → 0.52
    clipped_plus_bias = 0.6 + 0.05
    expected = clipped_plus_bias * 0.8 / 1.0
    assert act.transform_ctrl(1.0) == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_transform_ctrl_deadzone_zeroes_small_commands(act):
    """Commands that fall inside the deadzone after bias must come out as exactly 0."""
    small_commands = (
        0.05,   # 0.05 + bias 0.05 = 0.10 < dz_pos 0.15 → 0
        -0.15,  # -0.15 + bias 0.05 = -0.10 > -dz_neg -0.20 → 0
    )
    for command in small_commands:
        assert act.transform_ctrl(command) == 0.0
|
||||
|
||||
|
||||
def test_transform_ctrl_gear_compensation_is_asymmetric(act):
    """Positive and negative commands are scaled by different gear factors."""
    forward = act.transform_ctrl(0.5)   # (0.5 + 0.05 bias) * 0.8
    reverse = act.transform_ctrl(-0.5)  # (-0.5 + 0.05 bias) * 1.2
    expected_forward = 0.55 * 0.8
    expected_reverse = -0.45 * 1.2
    assert forward == pytest.approx(expected_forward)
    assert reverse == pytest.approx(expected_reverse)
|
||||
|
||||
|
||||
def test_transform_action_matches_transform_ctrl_elementwise(act):
    """The batched tensor path must agree with the scalar path on every sample."""
    samples = torch.linspace(-1.2, 1.2, 49)
    # A clone is passed so `samples` itself is untouched for the scalar pass.
    from_batch = act.transform_action(samples.clone())
    per_element = [act.transform_ctrl(float(sample)) for sample in samples]
    from_scalar = torch.tensor(per_element)
    assert torch.allclose(from_batch, from_scalar, atol=1e-6)
|
||||
|
||||
|
||||
# ── compute_motor_force: Coulomb + Stribeck + damping ────────────────
|
||||
|
||||
|
||||
def test_friction_opposes_motion(act):
    """With zero control the motor force has the opposite sign to velocity and is 0 at rest."""
    force_moving_forward = act.compute_motor_force(vel=2.0, ctrl=0.0)
    assert force_moving_forward < 0
    force_moving_backward = act.compute_motor_force(vel=-2.0, ctrl=0.0)
    assert force_moving_backward > 0
    force_at_rest = act.compute_motor_force(vel=0.0, ctrl=0.0)
    assert force_at_rest == 0.0
|
||||
|
||||
|
||||
def test_stribeck_boost_decays_with_speed(act):
    """Friction torque magnitude (minus damping) is higher near standstill."""
    # Two actuators that differ only in the Stribeck parameters; damping is
    # zeroed on both.
    no_strb = ActuatorConfig(
        joint="m",
        gear=act.gear,
        frictionloss=act.frictionloss,
        damping=(0.0, 0.0),
    )
    with_strb = ActuatorConfig(
        joint="m",
        gear=act.gear,
        frictionloss=act.frictionloss,
        damping=(0.0, 0.0),
        stribeck_friction_boost=0.07,
        stribeck_vel=5.0,
    )

    def extra_friction(vel):
        """Force magnitude added by the Stribeck terms at velocity `vel`."""
        boosted = abs(with_strb.compute_motor_force(vel, 0.0))
        plain = abs(no_strb.compute_motor_force(vel, 0.0))
        return boosted - plain

    v_slow, v_fast = 0.1, 50.0
    extra_slow = extra_friction(v_slow)
    extra_fast = extra_friction(v_fast)

    expected_slow = 0.07 * math.exp(-((v_slow / 5.0) ** 2))
    assert extra_slow == pytest.approx(expected_slow, abs=1e-9)
    assert extra_fast < extra_slow
    assert extra_fast == pytest.approx(0.0, abs=1e-9)
|
||||
|
||||
|
||||
def test_friction_scale_dr_multiplies_friction(act):
    """Doubling ``friction_scale`` doubles the force when ``damping_scale`` is 0."""

    def friction_only(scale):
        # damping_scale=0 isolates the friction term
        return act.compute_motor_force(
            1.0, 0.0, friction_scale=scale, damping_scale=0.0
        )

    base = friction_only(1.0)
    doubled = friction_only(2.0)
    assert doubled == pytest.approx(2.0 * base)
|
||||
173
tests/test_runner.py
Normal file
173
tests/test_runner.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Runner integration tests — DR, history, action delay, init randomization.
|
||||
|
||||
Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
|
||||
is skipped when JAX isn't installed.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
# String path of the "assets/rotary_cartpole" directory, resolved relative to
# the directory two levels above this file.
ROBOT_PATH = str(
    Path(__file__).resolve().parent.parent / Path("assets", "rotary_cartpole")
)
|
||||
|
||||
# Domain-randomization settings passed as `domain_rand` to the runner configs
# in this file. The two-element lists are checked as inclusive [low, high]
# bounds by test_dr_scales_sampled_within_ranges_and_resampled.
# NOTE(review): the *_noise_std entries are presumably noise standard
# deviations — confirm against the runner implementation.
DR = {
    "qpos_noise_std": 0.01,
    "qvel_noise_std": 0.5,
    "action_delay_steps": [0, 2],
    "friction_scale": [0.6, 1.6],
    "damping_scale": [0.6, 1.6],
    "torque_scale": [0.85, 1.15],
}
|
||||
|
||||
|
||||
def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner:
    """Build a MuJoCoRunner on device "cpu" around a fresh RotaryCartPoleEnv.

    A falsy ``domain_rand`` (``None`` or an empty mapping) is replaced by a
    new empty dict before it is handed to the runner config.
    """
    env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
    rand_settings = domain_rand if domain_rand else {}
    runner_config = MuJoCoRunnerConfig(
        num_envs=num_envs,
        device="cpu",
        history_length=history_length,
        domain_rand=rand_settings,
    )
    return MuJoCoRunner(env=env, config=runner_config)
|
||||
|
||||
|
||||
# ── Observation layout ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_obs_is_raw_plus_history():
    """The runner observation is the raw env observation followed by
    ``history_length`` slots of (raw observation + action)."""
    runner = _runner(num_envs=2, history_length=3)
    # try/finally so the runner is closed even when an assertion fails.
    try:
        raw_dim = runner.env.observation_space.shape[0]  # 6
        step_dim = raw_dim + 1  # + action
        assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim

        obs, _ = runner.reset()
        assert obs.shape == (2, raw_dim + 3 * step_dim)
        # Fresh history must be zero.
        assert torch.all(obs[:, raw_dim:] == 0)

        actions = torch.full((2, 1), 0.3)
        # Unused outputs keep underscore names; the 5-way unpack still
        # verifies the arity of step()'s return value.
        obs, rewards, _term, _trunc, _info = runner.step(actions)
        assert obs.shape == (2, raw_dim + 3 * step_dim)
        assert rewards.shape == (2, 1)
        # Newest history slot holds the commanded action.
        assert torch.allclose(obs[:, -1], torch.full((2,), 0.3))
    finally:
        runner.close()
|
||||
|
||||
|
||||
def test_no_history_keeps_plain_obs():
    """With ``history_length=0`` the runner observation has 6 entries."""
    runner = _runner(num_envs=2, history_length=0)
    # try/finally so the runner is closed even when the assertion fails.
    try:
        assert runner.observation_space.shape[0] == 6
    finally:
        runner.close()
|
||||
|
||||
|
||||
# ── Domain randomization ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_dr_scales_sampled_within_ranges_and_resampled():
    """DR scales stay inside their configured bounds, vary across envs, and
    change on the next reset; sampled action delays stay inside their bounds."""
    runner = _runner(num_envs=16, domain_rand=DR)
    # try/finally so the runner is closed even when an assertion fails.
    try:
        runner.reset()
        for name in ("friction_scale", "damping_scale", "torque_scale"):
            # Read the bounds from DR so they cannot drift from what the
            # runner was actually configured with.
            lo, hi = DR[name]
            vals = runner._dr_scales[name]
            assert torch.all(vals >= lo) and torch.all(vals <= hi)
            # 16 independent uniform samples are never all identical.
            assert vals.std() > 0

        before = runner._dr_scales["friction_scale"].clone()
        runner.reset()
        assert not torch.equal(before, runner._dr_scales["friction_scale"])

        delay_lo, delay_hi = DR["action_delay_steps"]
        delays = runner._dr_delay
        assert torch.all(delays >= delay_lo) and torch.all(delays <= delay_hi)
    finally:
        runner.close()
|
||||
|
||||
|
||||
def test_dr_disabled_is_noop():
    """With an empty ``domain_rand`` every scale is 1, and the max delay and
    qpos noise std are 0."""
    runner = _runner(num_envs=2, domain_rand={})
    # try/finally so the runner is closed even when an assertion fails.
    try:
        runner.reset()
        for vals in runner._dr_scales.values():
            assert torch.all(vals == 1.0)
        assert runner._max_delay == 0
        assert runner._qpos_noise_std == 0.0
    finally:
        runner.close()
|
||||
|
||||
|
||||
def test_action_delay_buffer_returns_lagged_action():
    """Env i, given a delay of i steps, must receive the action commanded i
    calls earlier (0 while the zeroed buffer has not filled yet)."""
    runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
    # try/finally so the runner is closed even when an assertion fails.
    try:
        runner.reset()
        # Pin the per-env delays and clear the buffer so the expected outputs
        # are deterministic.
        runner._dr_delay = torch.tensor([0, 1, 2])
        runner._action_buf.zero_()

        a1 = torch.tensor([[1.0], [1.0], [1.0]])
        a2 = torch.tensor([[2.0], [2.0], [2.0]])
        a3 = torch.tensor([[3.0], [3.0], [3.0]])

        d1 = runner._apply_action_delay(a1)
        d2 = runner._apply_action_delay(a2)
        d3 = runner._apply_action_delay(a3)

        assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0]
        assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0]
        assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0]
    finally:
        runner.close()
|
||||
|
||||
|
||||
# ── Initial-state randomization ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_wide_pendulum_init_actually_applied():
    """Column 1 of the reset qpos must be widely spread across 32 envs."""
    runner = _runner(num_envs=32)
    # try/finally so the runner is closed even when an assertion fails.
    try:
        qpos, _ = runner._sim_reset(torch.arange(32))
        pend_angles = qpos[:, 1]
        # With ±180° init range, samples must spread far beyond the old ±0.05.
        assert pend_angles.abs().max() > 1.0
        assert pend_angles.std() > 0.5
    finally:
        runner.close()
|
||||
|
||||
|
||||
def test_sim_reset_returns_full_batch():
    """Resetting a single env must still return qpos/qvel for all 4 envs."""
    runner = _runner(num_envs=4)
    # try/finally so the runner is closed even when an assertion fails.
    try:
        runner.reset()
        qpos, qvel = runner._sim_reset(torch.tensor([1]))  # reset one env only
        assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
    finally:
        runner.close()
|
||||
|
||||
|
||||
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
|
||||
|
||||
|
||||
def test_mjx_runner_smoke():
    """Smoke-test the MJX runner: observation shape, finite outputs over a few
    steps, and a full-batch reset. Skipped when jax or mujoco.mjx is missing."""
    pytest.importorskip("jax")
    pytest.importorskip("mujoco.mjx")
    # Imported only after the skip checks so a missing optional dependency
    # skips the test instead of erroring.
    from src.runners.mjx import MJXRunner, MJXRunnerConfig

    env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
    runner = MJXRunner(
        env=env,
        config=MJXRunnerConfig(
            num_envs=4, device="cpu", history_length=3, domain_rand=DR,
        ),
    )
    # try/finally so the runner is closed even when an assertion fails.
    try:
        obs, _ = runner.reset()
        raw_dim = env.observation_space.shape[0]
        assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1))

        for _ in range(3):
            actions = torch.rand(4, 1) * 2 - 1
            # Unused outputs keep underscore names; the 5-way unpack still
            # verifies the arity of step()'s return value.
            obs, rewards, _term, _trunc, _info = runner.step(actions)
            assert torch.isfinite(obs).all()
            assert torch.isfinite(rewards).all()

        # A reset through _sim_reset must return one qpos row per env.
        qpos, _qvel = runner._sim_reset(torch.arange(4))
        assert qpos.shape[0] == 4  # full batch
    finally:
        runner.close()
|
||||
Reference in New Issue
Block a user