"""Reward design tests — balancing must strictly dominate spinning.""" import math from pathlib import Path import pytest import torch from src.envs.rotary_cartpole import ( RotaryCartPoleConfig, RotaryCartPoleEnv, RotaryCartPoleState, ) ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole") def _env() -> RotaryCartPoleEnv: return RotaryCartPoleEnv(RotaryCartPoleConfig(robot_path=ROBOT_PATH)) def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState: t = lambda v: torch.tensor([float(v)]) return RotaryCartPoleState( motor_angle=t(motor), motor_vel=t(motor_vel), pendulum_angle=t(pend), pendulum_vel=t(pend_vel), ) def _reward(env, state, action=0.0, prev_action=0.0) -> float: a = torch.tensor([[float(action)]]) pa = torch.tensor([[float(prev_action)]]) return float(env.compute_rewards(state, a, pa)[0]) def test_balancing_beats_spinning_through_upright(): env = _env() balanced = _state(pend=math.pi, pend_vel=0.0) spinning = _state(pend=math.pi, pend_vel=10.0) # full-speed loop at the top assert _reward(env, balanced) > 2.0 * _reward(env, spinning) def test_average_spin_cycle_reward_below_balance(): """Mean reward over a full revolution at high speed << balanced reward.""" env = _env() angles = torch.linspace(0, 2 * math.pi, 32) spin_rewards = [ _reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles ] mean_spin = sum(spin_rewards) / len(spin_rewards) balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0)) assert balanced > 3.0 * mean_spin def test_motor_limit_violation_is_heavily_penalised_and_terminates(): env = _env() over_limit = _state(motor=math.radians(95.0), pend=math.pi) assert _reward(env, over_limit) == -10.0 assert bool(env.compute_terminations(over_limit)[0]) def test_action_rate_penalty_reduces_reward(): env = _env() s = _state(pend=math.pi) smooth = _reward(env, s, action=0.5, prev_action=0.5) jerky = _reward(env, s, action=0.5, prev_action=-0.5) assert smooth > jerky assert smooth - jerky == pytest.approx( env.config.action_rate_penalty * (0.5 - (-0.5)) ** 2, abs=1e-6, ) def test_initial_state_ranges_widen_pendulum_only(): env = _env() qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2) assert qpos_lo[0] == -0.05 and qpos_hi[0] == 0.05 assert qpos_lo[1] == -math.pi * (env.config.pendulum_init_range_deg / 180.0) assert qpos_hi[1] == math.pi * (env.config.pendulum_init_range_deg / 180.0) assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()