update urdf and dependencies

This commit is contained in:
2026-03-09 20:39:02 +01:00
parent c753c369b4
commit 15da0ef2fd
11 changed files with 204 additions and 57 deletions

View File

@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
], dim=-1)
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
# height: sin(θ) → -1 (down) to +1 (up)
height = torch.sin(state.pendulum_angle)
# Upright reward: -cos(θ) ∈ [-1, +1]
upright = -torch.cos(state.pendulum_angle)
# Upright reward: strongly rewards being near vertical.
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
# down (h=-1): 0.0
# horiz (h= 0): 0.0
# up (h=+1): 1.0
# Only kicks in above horizontal, so swing-up isn't penalised.
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
# Velocity penalties — make spinning expensive but allow swing-up
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
motor_vel_penalty = 0.01 * state.motor_vel ** 2
# Motor effort penalty: small cost to avoid bang-bang control.
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
return 5.0 * upright_reward - effort_penalty
return upright - pend_vel_penalty - motor_vel_penalty
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
# No early termination — episode runs for max_steps (truncation only).
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
def get_default_qpos(self, nq: int) -> list[float] | None:
# The STL mesh is horizontal at qpos=0.
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
import math
return [0.0, -math.pi / 2]
# qpos=0 = pendulum hanging down (joint frame rotated in URDF).
return None