174 lines
5.7 KiB
Python
174 lines
5.7 KiB
Python
"""Runner integration tests — DR, history, action delay, init randomization.

Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that
is skipped when JAX isn't installed.
"""
|
|
|
|
import math
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from src.core.registry import build_env
|
|
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
|
|
# Robot asset directory: "<parent of this file's directory>/assets/rotary_cartpole",
# resolved to an absolute path and kept as a plain string.
ROBOT_PATH = str(
    Path(__file__).resolve().parent.parent.joinpath("assets", "rotary_cartpole")
)
|
|
|
|
# Domain-randomization settings shared by the DR tests in this module.
# The two-element lists are [low, high] bounds: the tests below assert that the
# sampled scales and action delays stay inside them (inclusive).
# NOTE(review): the *_noise_std entries are presumably standard deviations of
# noise added to qpos/qvel on reset — confirm against the runner.
DR = {
    "qpos_noise_std": 0.01,
    "qvel_noise_std": 0.5,
    "action_delay_steps": [0, 2],
    "friction_scale": [0.6, 1.6],
    "damping_scale": [0.6, 1.6],
    "torque_scale": [0.85, 1.15],
}
|
|
|
|
|
|
def _runner(num_envs: int = 4, history_length: int = 3, domain_rand=None) -> MuJoCoRunner:
    """Build a CPU ``MuJoCoRunner`` around a fresh ``RotaryCartPoleEnv``.

    Args:
        num_envs: Number of environments passed to the runner config.
        history_length: History length passed to the runner config.
        domain_rand: Domain-randomization settings; ``None`` (or any other
            falsy value) is forwarded as an empty dict.

    Returns:
        The constructed runner. The caller is expected to ``close()`` it.
    """
    env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
    cfg = MuJoCoRunnerConfig(
        num_envs=num_envs,
        device="cpu",
        history_length=history_length,
        domain_rand=domain_rand or {},
    )
    return MuJoCoRunner(env=env, config=cfg)
|
|
|
|
|
|
# ── Observation layout ───────────────────────────────────────────────
|
|
|
|
|
|
def test_obs_is_raw_plus_history():
    """Observation width is the raw observation plus ``history_length`` steps.

    Each history step is sized as one raw observation plus one action value.
    After ``reset`` the history part must be all zeros; after one ``step`` the
    last observation column must equal the action that was just commanded.
    """
    num_envs, history_length = 2, 3
    runner = _runner(num_envs=num_envs, history_length=history_length)
    # Close the runner even when an assertion fails so a failing test does not
    # leak the simulator into the rest of the session.
    try:
        raw_dim = runner.env.observation_space.shape[0]  # 6
        step_dim = raw_dim + 1  # + action
        full_dim = raw_dim + history_length * step_dim
        assert runner.observation_space.shape[0] == full_dim

        obs, _ = runner.reset()
        assert obs.shape == (num_envs, full_dim)
        # Fresh history must be zero.
        assert torch.all(obs[:, raw_dim:] == 0)

        actions = torch.full((num_envs, 1), 0.3)
        obs, rewards, _term, _trunc, _info = runner.step(actions)
        assert obs.shape == (num_envs, full_dim)
        assert rewards.shape == (num_envs, 1)
        # Newest history slot holds the commanded action.
        assert torch.allclose(obs[:, -1], torch.full((num_envs,), 0.3))
    finally:
        runner.close()
|
|
|
|
|
|
def test_no_history_keeps_plain_obs():
    """With ``history_length=0`` the runner observation is 6 values wide."""
    runner = _runner(num_envs=2, history_length=0)
    # try/finally so the runner is released even if the assertion fails.
    try:
        assert runner.observation_space.shape[0] == 6
    finally:
        runner.close()
|
|
|
|
|
|
# ── Domain randomization ─────────────────────────────────────────────
|
|
|
|
|
|
def test_dr_scales_sampled_within_ranges_and_resampled():
    """DR scales and delays stay in their configured bounds and are resampled.

    Checks, for a 16-env runner configured with ``DR``:
    * each per-env scale tensor lies inside its ``[low, high]`` entry in ``DR``
      and is not constant across envs;
    * a second ``reset`` produces different friction scales;
    * per-env action delays lie inside ``DR["action_delay_steps"]``.
    """
    runner = _runner(num_envs=16, domain_rand=DR)
    # try/finally so the runner is released even if an assertion fails.
    try:
        runner.reset()
        # Bounds are read from DR itself so the test cannot drift from the
        # configuration it passes to the runner.
        for name in ("friction_scale", "damping_scale", "torque_scale"):
            lo, hi = DR[name]
            vals = runner._dr_scales[name]
            assert torch.all(vals >= lo) and torch.all(vals <= hi)
            # 16 independent uniform samples are never all identical.
            assert vals.std() > 0

        before = runner._dr_scales["friction_scale"].clone()
        runner.reset()
        assert not torch.equal(before, runner._dr_scales["friction_scale"])

        delay_lo, delay_hi = DR["action_delay_steps"]
        delays = runner._dr_delay
        assert torch.all(delays >= delay_lo) and torch.all(delays <= delay_hi)
    finally:
        runner.close()
|
|
|
|
|
|
def test_dr_disabled_is_noop():
    """An empty DR config leaves all scales at 1 with no delay or qpos noise."""
    runner = _runner(num_envs=2, domain_rand={})
    # try/finally so the runner is released even if an assertion fails.
    try:
        runner.reset()
        for vals in runner._dr_scales.values():
            assert torch.all(vals == 1.0)
        assert runner._max_delay == 0
        assert runner._qpos_noise_std == 0.0
    finally:
        runner.close()
|
|
|
|
|
|
def test_action_delay_buffer_returns_lagged_action():
    """``_apply_action_delay`` returns each env's action lagged by its delay.

    The three envs are forced to delays 0, 1 and 2 and the action buffer is
    zeroed, then the actions 1.0, 2.0, 3.0 are fed in order. Env ``i`` must
    see the action from ``i`` calls earlier, or 0.0 while the buffer is still
    empty.
    """
    runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]})
    # try/finally so the runner is released even if an assertion fails.
    try:
        runner.reset()
        # Override the sampled delays and clear the buffer so the expected
        # outputs below are deterministic.
        runner._dr_delay = torch.tensor([0, 1, 2])
        runner._action_buf.zero_()

        # Run all three calls before asserting, so a returned tensor that
        # aliased the internal buffer would be caught by the later checks.
        delayed = [
            runner._apply_action_delay(torch.full((3, 1), value))
            for value in (1.0, 2.0, 3.0)
        ]

        assert delayed[0].squeeze(-1).tolist() == [1.0, 0.0, 0.0]
        assert delayed[1].squeeze(-1).tolist() == [2.0, 1.0, 0.0]
        assert delayed[2].squeeze(-1).tolist() == [3.0, 2.0, 1.0]
    finally:
        runner.close()
|
|
|
|
|
|
# ── Initial-state randomization ──────────────────────────────────────
|
|
|
|
|
|
def test_wide_pendulum_init_actually_applied():
    """Resetting all 32 envs yields widely spread values in ``qpos[:, 1]``.

    Column 1 of the returned ``qpos`` is treated as the pendulum angle.
    NOTE(review): the thresholds look like radians — confirm against the env.
    """
    num_envs = 32
    runner = _runner(num_envs=num_envs)
    # try/finally so the runner is released even if an assertion fails.
    try:
        qpos, _ = runner._sim_reset(torch.arange(num_envs))
        pend_angles = qpos[:, 1]
        # With ±180° init range, samples must spread far beyond the old ±0.05.
        assert pend_angles.abs().max() > 1.0
        assert pend_angles.std() > 0.5
    finally:
        runner.close()
|
|
|
|
|
|
def test_sim_reset_returns_full_batch():
    """``_sim_reset`` on one env index still returns state for all 4 envs."""
    runner = _runner(num_envs=4)
    # try/finally so the runner is released even if an assertion fails.
    try:
        runner.reset()
        qpos, qvel = runner._sim_reset(torch.tensor([1]))  # reset one env only
        assert qpos.shape == (4, 2) and qvel.shape == (4, 2)
    finally:
        runner.close()
|
|
|
|
|
|
# ── MJX smoke (skipped without JAX) ──────────────────────────────────
|
|
|
|
|
|
def test_mjx_runner_smoke():
    """Smoke-test the MJX runner: reset, a few steps, and a full-batch reset.

    Skipped when ``jax`` or ``mujoco.mjx`` cannot be imported. Verifies the
    observation width with history enabled, that observations and rewards stay
    finite over three random-action steps, and that ``_sim_reset`` returns a
    ``qpos`` covering every env.
    """
    pytest.importorskip("jax")
    pytest.importorskip("mujoco.mjx")
    # Imported only after the skips above so a missing JAX/MJX install skips
    # the test instead of failing on import.
    from src.runners.mjx import MJXRunner, MJXRunnerConfig

    num_envs, history_length = 4, 3
    env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH))
    runner = MJXRunner(
        env=env,
        config=MJXRunnerConfig(
            num_envs=num_envs,
            device="cpu",
            history_length=history_length,
            domain_rand=DR,
        ),
    )
    # try/finally so the runner is released even if an assertion fails.
    try:
        obs, _ = runner.reset()
        raw_dim = env.observation_space.shape[0]
        # Raw observation plus `history_length` steps of (raw obs + action).
        assert obs.shape == (num_envs, raw_dim + history_length * (raw_dim + 1))

        for _ in range(3):
            actions = torch.rand(num_envs, 1) * 2 - 1
            obs, rewards, _term, _trunc, _info = runner.step(actions)
            assert torch.isfinite(obs).all()
            assert torch.isfinite(rewards).all()

        # MJX resets must return the full batch too. Only the batch dimension
        # is checked here; the init spread itself is not asserted.
        qpos, _qvel = runner._sim_reset(torch.arange(num_envs))
        assert qpos.shape[0] == num_envs  # full batch
    finally:
        runner.close()
|