# Rotary cart-pole (Furuta pendulum) environment.
import dataclasses
|
|
import torch
|
|
from src.core.env import BaseEnv, BaseEnvConfig
|
|
from gymnasium import spaces
|
|
|
|
|
|
@dataclasses.dataclass
class RotaryCartPoleState:
    """Batched joint-space state of the rotary cart-pole (Furuta pendulum).

    Each field is a 1-D tensor with one entry per parallel environment.
    Angles are presumably radians and velocities rad/s -- TODO confirm units
    against the simulator backend.
    """

    motor_angle: torch.Tensor  # (num_envs,) horizontal motor joint angle
    motor_vel: torch.Tensor  # (num_envs,) motor joint velocity
    pendulum_angle: torch.Tensor  # (num_envs,) pendulum joint angle; 0 = hanging down
    pendulum_vel: torch.Tensor  # (num_envs,) pendulum joint velocity
|
|
|
|
|
|
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely
    at the arm tip. Goal: swing the pendulum up and balance it upright.
    """

    # Reward shaping
    # NOTE(review): compute_rewards uses -cos(pendulum_angle) (= +1 when
    # upright) but does not appear to multiply by this scale -- confirm
    # whether the scale is applied elsewhere or is dead config.
    reward_upright_scale: float = 1.0  # upright term is +1 at the upright pose
|
|
|
|
|
|
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
        Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    # Joint column indices into qpos/qvel: motor first, pendulum second.
    _MOTOR = 0
    _PENDULUM = 1

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded 6-dim vector; layout documented in the class docstring.
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        # Single normalized motor torque command in [-1, 1].
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Slice the batched (num_envs, 2) qpos/qvel columns into a typed state."""
        m, p = self._MOTOR, self._PENDULUM
        return RotaryCartPoleState(
            motor_angle=qpos[:, m],
            motor_vel=qvel[:, m],
            pendulum_angle=qpos[:, p],
            pendulum_vel=qvel[:, p],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Encode both angles as sin/cos pairs and append the raw velocities."""
        features = [
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ]
        return torch.stack(features, dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Upright-cosine reward minus small quadratic velocity penalties.

        -cos(pendulum_angle) is -1 hanging down (angle 0, see get_default_qpos)
        and +1 upright. `actions` is accepted for the base-class signature but
        not penalized here.
        NOTE(review): config.reward_upright_scale is not applied in this
        method -- confirm whether the upright term should be scaled by it.
        """
        upright = -torch.cos(state.pendulum_angle)  # in [-1, +1]
        # Quadratic damping terms: make spinning expensive, still allow swing-up.
        pend_damping = 0.01 * state.pendulum_vel ** 2
        motor_damping = 0.01 * state.motor_vel ** 2
        return upright - pend_damping - motor_damping

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Never terminate early: episodes end only by truncation at max_steps.

        The agent must learn to swing up AND balance continuously.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        # qpos=0 corresponds to the pendulum hanging down (joint frame is
        # rotated in the URDF); returning None presumably keeps the base
        # class's default initial pose -- confirm against BaseEnv.
        return None
|