Files
RL-Sim-Framework/src/envs/rotary_cartpole.py
2026-03-09 20:39:02 +01:00

86 lines
3.0 KiB
Python

import dataclasses
import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces
@dataclasses.dataclass
class RotaryCartPoleState:
    """Batched joint state of the Furuta pendulum, one entry per parallel env.

    Angles are in radians; all tensors share the leading (num_envs,) shape.
    Built by ``RotaryCartPoleEnv.build_state`` from qpos/qvel columns 0 and 1.
    """

    motor_angle: torch.Tensor  # (num_envs,) horizontal motor/arm joint angle
    motor_vel: torch.Tensor  # (num_envs,) motor joint angular velocity
    pendulum_angle: torch.Tensor  # (num_envs,) pendulum joint angle; 0 = hanging down
    pendulum_vel: torch.Tensor  # (num_envs,) pendulum joint angular velocity
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely
    at the arm tip. Goal: swing the pendulum up and balance it upright.
    """

    # Reward shaping
    # Multiplier for the upright-alignment reward term (which is +1 when the
    # pendulum is upright, -1 when hanging).
    # NOTE(review): this field is not read by RotaryCartPoleEnv.compute_rewards
    # as written — confirm whether it is consumed elsewhere or was meant to
    # scale the upright term.
    reward_upright_scale: float = 1.0
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain:
        base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
    Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)
        # Cache the reward scale locally so compute_rewards does not depend on
        # how BaseEnv exposes its config object.
        self._reward_upright_scale = config.reward_upright_scale

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded box: the sin/cos components lie in [-1, 1], but the two
        # velocity components are not clipped.
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        # Normalized motor torque command.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Slice batched generalized coordinates into named state fields.

        Args:
            qpos: (num_envs, 2) joint positions — column 0 motor, column 1 pendulum.
            qvel: (num_envs, 2) joint velocities, same column layout.
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation tensor; angles encoded as sin/cos."""
        return torch.stack([
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ], dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Per-env scalar reward: scaled upright alignment minus velocity penalties.

        ``actions`` is accepted for interface compatibility but not penalized here.
        """
        # qpos=0 means hanging down (see get_default_qpos), so the alignment
        # term is -cos(θ) ∈ [-1, +1]: +1 upright, -1 hanging.
        # Fix: apply config.reward_upright_scale, which was previously declared
        # but never used (default 1.0 preserves the old reward exactly).
        upright = self._reward_upright_scale * -torch.cos(state.pendulum_angle)
        # Velocity penalties — make spinning expensive but allow swing-up.
        pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
        motor_vel_penalty = 0.01 * state.motor_vel ** 2
        return upright - pend_vel_penalty - motor_vel_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Never terminate early; episodes end by truncation at max_steps only.

        The agent must learn to swing up AND balance continuously.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        # qpos=0 = pendulum hanging down (joint frame rotated in URDF), so the
        # simulator's all-zeros default is already correct.
        return None