♻️ full agent refactor
This commit is contained in:
79
tests/test_reward.py
Normal file
79
tests/test_reward.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Reward design tests — balancing must strictly dominate spinning."""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from src.envs.rotary_cartpole import (
|
||||
RotaryCartPoleConfig,
|
||||
RotaryCartPoleEnv,
|
||||
RotaryCartPoleState,
|
||||
)
|
||||
|
||||
# Robot asset directory handed to RotaryCartPoleConfig: located two directory
# levels above this test file, under "assets/rotary_cartpole", as a plain string.
ROBOT_PATH = str(Path(__file__).resolve().parent.parent / "assets" / "rotary_cartpole")
|
||||
|
||||
|
||||
def _env() -> RotaryCartPoleEnv:
    """Create a fresh environment configured with the shared robot asset path.

    Every other config field is left at whatever default RotaryCartPoleConfig
    defines.
    """
    config = RotaryCartPoleConfig(robot_path=ROBOT_PATH)
    return RotaryCartPoleEnv(config)
|
||||
|
||||
|
||||
def _state(motor=0.0, motor_vel=0.0, pend=0.0, pend_vel=0.0) -> RotaryCartPoleState:
    """Build a RotaryCartPoleState from plain Python scalars.

    Each argument is converted with ``float`` and wrapped in a one-element
    1-D tensor before being stored in the matching state field.

    Args:
        motor: value for ``motor_angle``.
        motor_vel: value for ``motor_vel``.
        pend: value for ``pendulum_angle``.
        pend_vel: value for ``pendulum_vel``.

    Returns:
        The populated state object.
    """

    def _as_tensor(value) -> torch.Tensor:
        # One-element tensor so each field carries a leading dimension of 1.
        return torch.tensor([float(value)])

    return RotaryCartPoleState(
        motor_angle=_as_tensor(motor),
        motor_vel=_as_tensor(motor_vel),
        pendulum_angle=_as_tensor(pend),
        pendulum_vel=_as_tensor(pend_vel),
    )
|
||||
|
||||
|
||||
def _reward(env, state, action=0.0, prev_action=0.0) -> float:
    """Evaluate ``env.compute_rewards`` for one state and return a Python float.

    Args:
        env: environment whose ``compute_rewards`` is called.
        state: state object passed through unchanged.
        action: current action scalar, wrapped as a ``(1, 1)`` tensor.
        prev_action: previous action scalar, wrapped as a ``(1, 1)`` tensor.

    Returns:
        Element 0 of the result of ``compute_rewards``, converted to ``float``.
    """
    action_t = torch.tensor([[float(action)]])
    prev_action_t = torch.tensor([[float(prev_action)]])
    rewards = env.compute_rewards(state, action_t, prev_action_t)
    return float(rewards[0])
|
||||
|
||||
|
||||
def test_balancing_beats_spinning_through_upright():
    """A motionless pendulum at pi must earn more than twice the reward of one
    passing through the same angle at pendulum velocity 10.0."""
    env = _env()

    # Same pendulum angle in both states; only the pendulum velocity differs.
    still_state = _state(pend=math.pi, pend_vel=0.0)
    looping_state = _state(pend=math.pi, pend_vel=10.0)  # full-speed loop at the top

    balanced_reward = _reward(env, still_state)
    spinning_reward = _reward(env, looping_state)

    assert balanced_reward > 2.0 * spinning_reward
|
||||
|
||||
|
||||
def test_average_spin_cycle_reward_below_balance():
    """Mean reward over a full revolution at high speed << balanced reward.

    The pendulum angle is swept over one revolution at pendulum velocity 10.0
    and the rewards are averaged; the reward of a motionless pendulum at pi
    must exceed three times that average.
    """
    env = _env()

    # torch.linspace includes BOTH endpoints, and 0 and 2*pi are the same
    # physical angle. Sampling n + 1 points and dropping the last gives n
    # evenly spaced angles on [0, 2*pi) so no angle is double-weighted in the
    # mean. NOTE(review): assumes the reward is periodic in the pendulum
    # angle — confirm against compute_rewards.
    n_samples = 32
    angles = torch.linspace(0, 2 * math.pi, n_samples + 1)[:-1]

    spin_rewards = [
        _reward(env, _state(pend=float(a), pend_vel=10.0)) for a in angles
    ]
    mean_spin = sum(spin_rewards) / len(spin_rewards)

    balanced = _reward(env, _state(pend=math.pi, pend_vel=0.0))
    assert balanced > 3.0 * mean_spin
|
||||
|
||||
|
||||
def test_motor_limit_violation_is_heavily_penalised_and_terminates():
    """A motor angle of 95 degrees must yield a reward of exactly -10.0 and
    flag the episode as terminated."""
    env = _env()
    over_limit = _state(motor=math.radians(95.0), pend=math.pi)

    penalty = _reward(env, over_limit)
    assert penalty == -10.0

    terminated = env.compute_terminations(over_limit)
    assert bool(terminated[0])
|
||||
|
||||
|
||||
def test_action_rate_penalty_reduces_reward():
    """Changing the action between steps must lower the reward by
    ``action_rate_penalty * (action - prev_action) ** 2``."""
    env = _env()
    state = _state(pend=math.pi)

    # Same current action in both cases; only the previous action differs.
    action = 0.5
    jerky_prev = -0.5

    smooth = _reward(env, state, action=action, prev_action=action)
    jerky = _reward(env, state, action=action, prev_action=jerky_prev)
    assert smooth > jerky

    expected_gap = env.config.action_rate_penalty * (action - jerky_prev) ** 2
    assert smooth - jerky == pytest.approx(expected_gap, abs=1e-6)
|
||||
|
||||
|
||||
def test_initial_state_ranges_widen_pendulum_only():
    """Check the bounds returned by ``env.initial_state_ranges(2, 2)``.

    Position entry 0 must stay at +/-0.05, position entry 1 must span
    +/- ``pendulum_init_range_deg`` converted to radians, and every velocity
    bound must be +/-0.05.
    """
    env = _env()
    bounds = env.initial_state_ranges(2, 2)
    qpos_lo, qpos_hi, qvel_lo, qvel_hi = bounds

    # Entry 0 keeps the narrow range.
    assert qpos_lo[0] == -0.05
    assert qpos_hi[0] == 0.05

    # Entry 1 is widened to the configured range (degrees -> radians).
    pend_half_range = math.pi * (env.config.pendulum_init_range_deg / 180.0)
    assert qpos_lo[1] == -pend_half_range
    assert qpos_hi[1] == pend_half_range

    # All velocity bounds keep the narrow range.
    assert (qvel_lo == -0.05).all()
    assert (qvel_hi == 0.05).all()
|
||||
Reference in New Issue
Block a user