80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
"""Reward design tests — balancing must strictly dominate spinning."""
|
|
|
|
import math
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from src.envs.rotary_cartpole import (
|
|
RotaryCartPoleConfig,
|
|
RotaryCartPoleEnv,
|
|
RotaryCartPoleState,
|
|
)
|
|
|
|
# Robot asset directory, located two levels above this file (the file's
# grandparent directory) so the tests work from any working directory.
ROBOT_PATH = str(
    Path(__file__).resolve().parent.parent.joinpath("assets", "rotary_cartpole")
)
|
|
|
|
|
|
def _env() -> RotaryCartPoleEnv:
    """Construct a fresh environment configured with the bundled robot assets."""
    config = RotaryCartPoleConfig(robot_path=ROBOT_PATH)
    return RotaryCartPoleEnv(config)
|
|
|
|
|
|
def _state(
    motor: float = 0.0,
    motor_vel: float = 0.0,
    pend: float = 0.0,
    pend_vel: float = 0.0,
) -> RotaryCartPoleState:
    """Build a one-element state for reward/termination queries.

    Each field is wrapped in a shape ``(1,)`` tensor (presumably a batch of a
    single environment — confirm against ``RotaryCartPoleState``).

    Args:
        motor: Motor angle.
        motor_vel: Motor angular velocity.
        pend: Pendulum angle (the tests treat ``math.pi`` as upright).
        pend_vel: Pendulum angular velocity.
    """

    def _as_batch(value: float) -> torch.Tensor:
        # Named helper instead of an assigned lambda (PEP 8, E731).
        return torch.tensor([float(value)])

    return RotaryCartPoleState(
        motor_angle=_as_batch(motor),
        motor_vel=_as_batch(motor_vel),
        pendulum_angle=_as_batch(pend),
        pendulum_vel=_as_batch(pend_vel),
    )
|
|
|
|
|
|
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
    """Return the first reward ``env.compute_rewards`` yields for one state.

    The current and previous actions are each packed into a ``(1, 1)`` tensor
    before being handed to the environment.
    """
    actions = torch.tensor([[float(action)]])
    previous_actions = torch.tensor([[float(prev_action)]])
    rewards = env.compute_rewards(state, actions, previous_actions)
    return float(rewards[0])
|
|
|
|
|
|
def test_balancing_beats_spinning_through_upright():
    """Holding still at upright must strictly out-earn whipping through upright."""
    env = _env()
    balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
    # Full-speed loop at the top: same angle, pendulum velocity 10.0.
    spinning = _reward(env, _state(pend=math.pi, pend_vel=10.0))
    # The ratio check alone is sign-sensitive: with two negative rewards,
    # ``balanced > 2 * spinning`` can hold even when balanced <= spinning,
    # so strict dominance is asserted explicitly as well.
    assert balanced > spinning
    assert balanced > 2.0 * spinning
|
|
|
|
|
|
def test_average_spin_cycle_reward_below_balance():
    """Mean reward over a full revolution at high speed << balanced reward."""
    env = _env()
    n_samples = 32
    # Uniform samples over one revolution, excluding the closing endpoint:
    # 0 and 2*pi are the same pendulum angle, so an inclusive grid such as
    # linspace(0, 2*pi, 32) counts that position twice and biases the mean.
    # This grid also lands exactly on pi (upright) at k = n_samples / 2.
    angles = [2.0 * math.pi * k / n_samples for k in range(n_samples)]
    spin_rewards = [_reward(env, _state(pend=a, pend_vel=10.0)) for a in angles]
    mean_spin = sum(spin_rewards) / len(spin_rewards)
    balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
    # Direct comparison first: the 3x ratio alone is vacuous if both are negative.
    assert balanced > mean_spin
    assert balanced > 3.0 * mean_spin
|
|
|
|
|
|
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
    """Past the motor limit the reward is exactly -10 and the episode terminates."""
    env = _env()
    # 95 degrees is assumed to lie beyond the configured motor limit.
    state = _state(motor=math.radians(95.0), pend=math.pi)
    reward = _reward(env, state)
    assert reward == -10.0
    terminated = env.compute_terminations(state)
    assert bool(terminated[0])
|
|
|
|
|
|
def test_action_rate_penalty_reduces_reward():
    """Changing the action between steps costs exactly penalty * (delta action)**2."""
    env = _env()
    state = _state(pend=math.pi)
    held, previous = 0.5, -0.5
    smooth = _reward(env, state, action=held, prev_action=held)
    jerky = _reward(env, state, action=held, prev_action=previous)
    assert smooth > jerky
    expected_gap = env.config.action_rate_penalty * (held - previous) ** 2
    assert smooth - jerky == pytest.approx(expected_gap, abs=1e-6)
|
|
|
|
|
|
def test_initial_state_ranges_widen_pendulum_only():
    """Only the pendulum's initial-angle range is widened; others stay at +/-0.05."""
    env = _env()
    qpos_lo, qpos_hi, qvel_lo, qvel_hi = env.initial_state_ranges(2, 2)
    # The pendulum bound is a derived radian value. Exact equality would only
    # pass if the implementation used the very same float expression and
    # dtype (pi * (deg / 180) vs radians(deg) can differ by an ulp), so
    # compare scalars approximately after converting to Python floats.
    pend_range = math.radians(env.config.pendulum_init_range_deg)
    assert float(qpos_lo[0]) == pytest.approx(-0.05)
    assert float(qpos_hi[0]) == pytest.approx(0.05)
    assert float(qpos_lo[1]) == pytest.approx(-pend_range)
    assert float(qpos_hi[1]) == pytest.approx(pend_range)
    assert (qvel_lo == -0.05).all() and (qvel_hi == 0.05).all()
|