"""Runner integration tests — DR, history, action delay, init randomization. Uses the CPU MuJoCo runner (small env counts). MJX gets a smoke test that is skipped when JAX isn't installed. """ import math from pathlib import Path import pytest import torch from src.core.registry import build_env from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole") DR = { "qpos_noise_std": 0.01, "qvel_noise_std": 0.5, "action_delay_steps": [0, 2], "friction_scale": [0.6, 1.6], "damping_scale": [0.6, 1.6], "torque_scale": [0.85, 1.15], } def _runner(num_envs=4, history_length=3, domain_rand=None) -> MuJoCoRunner: env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH)) cfg = MuJoCoRunnerConfig( num_envs=num_envs, device="cpu", history_length=history_length, domain_rand=domain_rand or {}, ) return MuJoCoRunner(env=env, config=cfg) # ── Observation layout ─────────────────────────────────────────────── def test_obs_is_raw_plus_history(): runner = _runner(num_envs=2, history_length=3) raw_dim = runner.env.observation_space.shape[0] # 6 step_dim = raw_dim + 1 # + action assert runner.observation_space.shape[0] == raw_dim + 3 * step_dim obs, _ = runner.reset() assert obs.shape == (2, raw_dim + 3 * step_dim) # Fresh history must be zero. assert torch.all(obs[:, raw_dim:] == 0) actions = torch.full((2, 1), 0.3) obs, rewards, term, trunc, info = runner.step(actions) assert obs.shape == (2, raw_dim + 3 * step_dim) assert rewards.shape == (2, 1) # Newest history slot holds the commanded action. assert torch.allclose(obs[:, -1], torch.full((2,), 0.3)) runner.close() def test_no_history_keeps_plain_obs(): runner = _runner(num_envs=2, history_length=0) assert runner.observation_space.shape[0] == 6 runner.close() # ── Domain randomization ───────────────────────────────────────────── def test_dr_scales_sampled_within_ranges_and_resampled(): runner = _runner(num_envs=16, domain_rand=DR) runner.reset() for name, (lo, hi) in ( ("friction_scale", (0.6, 1.6)), ("damping_scale", (0.6, 1.6)), ("torque_scale", (0.85, 1.15)), ): vals = runner._dr_scales[name] assert torch.all(vals >= lo) and torch.all(vals <= hi) # 16 independent uniform samples are never all identical. assert vals.std() > 0 before = runner._dr_scales["friction_scale"].clone() runner.reset() assert not torch.equal(before, runner._dr_scales["friction_scale"]) delays = runner._dr_delay assert torch.all(delays >= 0) and torch.all(delays <= 2) runner.close() def test_dr_disabled_is_noop(): runner = _runner(num_envs=2, domain_rand={}) runner.reset() for vals in runner._dr_scales.values(): assert torch.all(vals == 1.0) assert runner._max_delay == 0 assert runner._qpos_noise_std == 0.0 runner.close() def test_action_delay_buffer_returns_lagged_action(): runner = _runner(num_envs=3, domain_rand={"action_delay_steps": [0, 2]}) runner.reset() runner._dr_delay = torch.tensor([0, 1, 2]) runner._action_buf.zero_() a1 = torch.tensor([[1.0], [1.0], [1.0]]) a2 = torch.tensor([[2.0], [2.0], [2.0]]) a3 = torch.tensor([[3.0], [3.0], [3.0]]) d1 = runner._apply_action_delay(a1) d2 = runner._apply_action_delay(a2) d3 = runner._apply_action_delay(a3) assert d1.squeeze(-1).tolist() == [1.0, 0.0, 0.0] assert d2.squeeze(-1).tolist() == [2.0, 1.0, 0.0] assert d3.squeeze(-1).tolist() == [3.0, 2.0, 1.0] runner.close() # ── Initial-state randomization ────────────────────────────────────── def test_wide_pendulum_init_actually_applied(): runner = _runner(num_envs=32) qpos, _ = runner._sim_reset(torch.arange(32)) pend_angles = qpos[:, 1] # With ±180° init range, samples must spread far beyond the old ±0.05. assert pend_angles.abs().max() > 1.0 assert pend_angles.std() > 0.5 runner.close() def test_sim_reset_returns_full_batch(): runner = _runner(num_envs=4) runner.reset() qpos, qvel = runner._sim_reset(torch.tensor([1])) # reset one env only assert qpos.shape == (4, 2) and qvel.shape == (4, 2) runner.close() # ── MJX smoke (skipped without JAX) ────────────────────────────────── def test_mjx_runner_smoke(): pytest.importorskip("jax") pytest.importorskip("mujoco.mjx") from src.runners.mjx import MJXRunner, MJXRunnerConfig env = RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH)) runner = MJXRunner( env=env, config=MJXRunnerConfig( num_envs=4, device="cpu", history_length=3, domain_rand=DR, ), ) obs, _ = runner.reset() raw_dim = env.observation_space.shape[0] assert obs.shape == (4, raw_dim + 3 * (raw_dim + 1)) for _ in range(3): actions = torch.rand(4, 1) * 2 - 1 obs, rewards, term, trunc, _ = runner.step(actions) assert torch.isfinite(obs).all() assert torch.isfinite(rewards).all() # Wide pendulum init must reach MJX resets too. qpos, qvel = runner._sim_reset(torch.arange(4)) assert qpos.shape[0] == 4 # full batch runner.close()