add rotary cartpole env

This commit is contained in:
2026-03-08 22:58:32 +01:00
parent c8f28ffbcc
commit c753c369b4
15 changed files with 464 additions and 171 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,105 @@
<?xml version="1.0" encoding="utf-8"?>
<robot name="rotary_cartpole">
<!-- Fixed world frame -->
<link name="world"/>
<!-- Base: motor housing, fixed to world -->
<link name="base_link">
<inertial>
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0"/>
<mass value="0.921"/>
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559"
ixy="0.0" iyz="-0.000149" ixz="6e-06"/>
</inertial>
<visual>
<origin xyz="0 0 0" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
</geometry>
</visual>
<collision>
<origin xyz="0 0 0" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
</geometry>
</collision>
</link>
<joint name="base_joint" type="fixed">
<parent link="world"/>
<child link="base_link"/>
</joint>
<!-- Arm: horizontal rotating arm driven by the motor.
     Fusion exported 279 g (dense material assumed); the real part is ~10 g.
     NOTE(review): the <mass> below is 0.150 kg, which matches neither value —
     confirm this is an intentional tuning choice. -->
<link name="arm">
<inertial>
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
<mass value="0.150"/>
<inertia ixx="4.05e-05" iyy="1.17e-05" izz="3.66e-05"
ixy="0.0" iyz="1.08e-07" ixz="0.0"/>
</inertial>
<visual>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
</geometry>
</visual>
<collision>
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
</geometry>
</collision>
</link>
<!-- Motor joint: base → arm, rotates around vertical z-axis -->
<joint name="motor_joint" type="revolute">
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0"/>
<parent link="base_link"/>
<child link="arm"/>
<axis xyz="0 0 1"/>
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0"/>
<dynamics damping="0.001"/>
</joint>
<!-- Pendulum: swings freely at the end of the arm.
Real mass: 5g pendulum + 10g weight at the tip (70mm from bearing) = 15g total.
(Fusion assumed dense material, exported 57g for the pendulum alone.) -->
<link name="pendulum">
<inertial>
<!-- Combined CoM: 5 g rod (CoM ~35 mm) + 10 g tip weight at 70 mm from pivot.
     Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
     CoM = (5×0.035 + 10×0.07)/15 = 0.0583, applied along both +X and -Y.
     NOTE(review): if 35/70 mm are distances *along* the rod, the per-axis
     components should be 0.0583/√2 ≈ 0.0412; the values used are consistent
     only if 35/70 mm are per-axis projections — confirm which was meant.
     Inertia tensor rotated 45° to match the diagonal rod axis. -->
<origin xyz="0.0583 -0.0583 0.0" rpy="0 0 0"/>
<mass value="0.015"/>
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
</inertial>
<visual>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
</geometry>
</visual>
<collision>
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
<geometry>
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
</geometry>
</collision>
</link>
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). -->
<joint name="pendulum_joint" type="continuous">
<origin xyz="0.000052 0.019274 0.014993" rpy="0 0 0"/>
<parent link="arm"/>
<child link="pendulum"/>
<axis xyz="0 -1 0"/>
<dynamics damping="0.0005"/>
</joint>
</robot>

7
configs/env/rotary_cartpole.yaml vendored Normal file
View File

@@ -0,0 +1,7 @@
# Rotary cartpole (Furuta pendulum) environment config.
max_steps: 1000  # episode length; the env truncates rather than terminates
model_path: assets/rotary_cartpole/rotary_cartpole.urdf
# Scale for the upright reward term.
# NOTE(review): verify the env actually reads this field — confirm wiring.
reward_upright_scale: 1.0
actuators:
  - joint: motor_joint  # torque actuator injected on the arm's drive joint
    gear: 15.0  # actuator gear (MuJoCo motor gear — scales ctrl to torque)
    ctrl_range: [-1.0, 1.0]  # normalized control input bounds

View File

@@ -1,4 +1,5 @@
num_envs: 16
num_envs: 64
device: cpu
dt: 0.02
substeps: 2
dt: 0.002
substeps: 20
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)

View File

@@ -1,5 +1,5 @@
hidden_sizes: [128, 128]
total_timesteps: 1000000
total_timesteps: 5000000
rollout_steps: 1024
learning_epochs: 4
mini_batches: 4
@@ -8,6 +8,8 @@ gae_lambda: 0.95
learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
entropy_loss_scale: 0.01
entropy_loss_scale: 0.05
log_interval: 10
clearml_project: RL-Framework
# ClearML remote execution (GPU worker)
remote: false

View File

@@ -5,4 +5,6 @@ omegaconf
mujoco
skrl[torch]
clearml
imageio
imageio-ffmpeg
pytest

View File

@@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]):
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of envs whose step counter reached the episode limit."""
    limit = self.config.max_steps
    return step_counts.ge(limit)
def get_default_qpos(self, nq: int) -> list[float] | None:
"""Return the default joint positions for reset.
Override in subclass if the URDF zero pose doesn't match
the desired initial state (e.g. pendulum hanging down).
Returns None to use the URDF default (all zeros)."""
return None

View File

@@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]):
# skrl expects (num_envs, 1) for rewards/terminated/truncated
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
raise NotImplementedError("Render method not implemented for this runner.")
def render(self, env_idx: int = 0):
"""Offscreen render → RGB numpy array. Override in subclass."""
raise NotImplementedError("Render not implemented for this runner.")
def close(self) -> None:
self._sim_close()

View File

@@ -0,0 +1,94 @@
import dataclasses
import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces
@dataclasses.dataclass
class RotaryCartPoleState:
    """Per-step joint state; every field is a (num_envs,) tensor."""

    motor_angle: torch.Tensor  # (num_envs,) arm joint angle
    motor_vel: torch.Tensor  # (num_envs,) arm joint velocity
    pendulum_angle: torch.Tensor  # (num_envs,) pendulum joint angle
    pendulum_vel: torch.Tensor  # (num_envs,) pendulum joint velocity
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely
    at the arm tip. Goal: swing the pendulum up and balance it upright.
    """

    # Reward shaping
    # Intended multiplier for the upright reward term.
    # NOTE(review): compute_rewards hard-codes its weighting and does not
    # read this field as written — confirm wiring before tuning it.
    reward_upright_scale: float = 1.0
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain:
        base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
        sin/cos encoding avoids the ±π discontinuity of continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)

    @property
    def observation_space(self) -> spaces.Space:
        # Unbounded box: velocities are not clipped here; normalization is
        # the trainer's concern.
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Slice the raw (num_envs, nq)/(num_envs, nv) sim buffers into named views.

        Joint order follows the URDF: index 0 = motor_joint, 1 = pendulum_joint.
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation tensor described in the class docstring."""
        return torch.stack([
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ], dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Dense swing-up / balance reward.

        reward = 5 * reward_upright_scale * clamp(sin θ, 0, 1)²  -  0.001 * a²

        The upright term is 0 at or below horizontal and 1 when vertical, so
        the swing-up phase is never penalised; the small effort term
        discourages bang-bang control.
        """
        # height: sin(θ) → -1 (hanging down) to +1 (upright)
        height = torch.sin(state.pendulum_angle)
        # Only kicks in above horizontal; squared to sharpen near vertical.
        upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
        # Motor effort penalty: small cost to avoid bang-bang control.
        effort_penalty = 0.001 * actions.squeeze(-1) ** 2
        # Fix: honour config.reward_upright_scale (it was defined but never
        # read). With the default scale of 1.0 this reproduces the previous
        # hard-coded 5.0 weighting exactly.
        return 5.0 * self.config.reward_upright_scale * upright_reward - effort_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """No early termination — episodes end only by truncation at max_steps,
        so the agent must learn to swing up AND balance continuously."""
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        """Initial pose: arm centred, pendulum hanging straight down.

        The STL mesh is horizontal at qpos = 0, so "hanging" is θ = -π/2
        (sin(-π/2) = -1, matching the height convention in compute_rewards).
        """
        import math
        return [0.0, -math.pi / 2]

View File

@@ -4,7 +4,7 @@ from gymnasium import spaces
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
DeterministicMixin.__init__(self, clip_actions)

View File

@@ -1,12 +1,11 @@
import dataclasses
import tempfile
import os
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
import mujoco.viewer
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
@@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
# Step 1: Load URDF/MJCF as-is (no actuators)
model_raw = mujoco.MjModel.from_xml_path(model_path)
abs_path = os.path.abspath(model_path)
model_dir = os.path.dirname(abs_path)
is_urdf = abs_path.lower().endswith(".urdf")
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
# so we inject a <mujoco><compiler meshdir="..."/> block into a
# temporary copy. The original URDF stays clean and simulator-agnostic.
if is_urdf:
tree = ET.parse(abs_path)
root = tree.getroot()
# Detect the mesh subdirectory from the first mesh filename
meshdir = None
for mesh_el in root.iter("mesh"):
fn = mesh_el.get("filename", "")
dirname = os.path.dirname(fn)
if dirname:
meshdir = dirname
break
if meshdir:
mj_ext = ET.SubElement(root, "mujoco")
ET.SubElement(mj_ext, "compiler", attrib={
"meshdir": meshdir,
"balanceinertia": "true",
})
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
try:
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
finally:
os.unlink(tmp_urdf)
else:
model_raw = mujoco.MjModel.from_xml_path(abs_path)
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation
tmp_mjcf = tempfile.mktemp(suffix=".xml")
# Step 2: Export internal MJCF representation (save next to original
# model so relative mesh/asset paths resolve correctly on reload)
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
# at steady state, vel_max = torque / damping.
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Add damping to actuated joints to limit max speed and
# mimic real motor friction / back-EMF.
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
actuated_joints = {a.joint for a in actuators}
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("name") in actuated_joints:
jnt.set("damping", "0.05")
# Disable self-collision on all geoms.
# URDF mesh convex hulls often overlap at joints (especially
# grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
# causing phantom contact forces.
for geom in root.iter("geom"):
geom.set("contype", "0")
geom.set("conaffinity", "0")
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode")
with open(tmp_mjcf, "w") as f:
f.write(modified_xml)
return mujoco.MjModel.from_xml_path(tmp_mjcf)
finally:
import os
os.unlink(tmp_mjcf)
# Step 3: Inject actuators into the MJCF XML
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Step 4: Reload from modified MJCF
modified_xml = ET.tostring(root, encoding="unicode")
return mujoco.MjModel.from_xml_string(modified_xml)
if os.path.exists(tmp_mjcf):
os.unlink(tmp_mjcf)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._nq = self._model.nq
self._nv = self._model.nv
# Per-env smoothed ctrl state for EMA action filtering.
# Models real motor inertia: ctrl can't reverse instantly.
nu = self._model.nu
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
alpha = self.config.action_ema_alpha
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data):
data.ctrl[:] = actions_np[i]
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
data.ctrl[:] = self._smooth_ctrl[i]
for _ in range(self.config.substeps):
mujoco.mj_step(self._model, data)
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
default_qpos = self.env.get_default_qpos(self._nq)
for i, env_id in enumerate(ids):
data = self._data[env_id]
mujoco.mj_resetData(self._model, data)
# Add small random perturbation so the pole doesn't start perfectly upright
# Set initial pose (env-specific, e.g. pendulum hanging down)
if default_qpos is not None:
data.qpos[:] = default_qpos
# Add small random perturbation for exploration
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
# Reset smoothed ctrl so motor starts from rest
self._smooth_ctrl[env_id][:] = 0.0
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
)
def _sim_close(self) -> None:
if hasattr(self, "_viewer") and self._viewer is not None:
self._viewer.close()
self._viewer = None
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
self._offscreen_renderer = None
self._data.clear()
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
if mode == "human":
if not hasattr(self, "_viewer") or self._viewer is None:
self._viewer = mujoco.viewer.launch_passive(
self._model, self._data[env_idx]
)
# Update visual geometry from current physics state
mujoco.mj_forward(self._model, self._data[env_idx])
self._viewer.sync()
return None
elif mode == "rgb_array":
# Cache the offscreen renderer to avoid create/destroy overhead
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
self._offscreen_renderer.update_scene(self._data[env_idx])
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
return torch.from_numpy(pixels)
def render(self, env_idx: int = 0) -> np.ndarray | None:
    """Offscreen render of env `env_idx` → RGB numpy array (H, W, 3)."""
    renderer = getattr(self, "_offscreen_renderer", None)
    if renderer is None:
        # Lazily create and cache the renderer on first use.
        renderer = mujoco.Renderer(self._model, height=480, width=640)
        self._offscreen_renderer = renderer
    renderer.update_scene(self._data[env_idx])
    # Copy: the renderer reuses its internal pixel buffer between calls.
    return renderer.render().copy()

View File

@@ -4,19 +4,22 @@ import tempfile
from pathlib import Path
import numpy as np
import torch
import tqdm
from clearml import Logger
from gymnasium import spaces
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.memories.torch import RandomMemory
from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.trainers.torch import SequentialTrainer
from src.core.runner import BaseRunner
from clearml import Task, Logger
import torch
from gymnasium import spaces
from skrl.memories.torch import RandomMemory
from src.models.mlp import SharedMLP
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.trainers.torch import SequentialTrainer
@dataclasses.dataclass
class TrainerConfig:
# PPO
rollout_steps: int = 2048
learning_epochs: int = 8
mini_batches: int = 4
@@ -29,30 +32,27 @@ class TrainerConfig:
hidden_sizes: tuple[int, ...] = (64, 64)
# Training
total_timesteps: int = 1_000_000
log_interval: int = 10
# Video recording
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
# Video recording (uploaded to ClearML)
record_video_every: int = 10_000 # 0 = disabled
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
clearml_project: str | None = None
clearml_task: str | None = None
# ── Video-recording trainer ──────────────────────────────────────────
class VideoRecordingTrainer(SequentialTrainer):
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
super().__init__(env=env, agents=agents, cfg=cfg)
self._trainer_config = trainer_config
self._tcfg = trainer_config
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
def single_agent_train(self) -> None:
"""Override to add periodic video recording."""
assert self.num_simultaneous_agents == 1
assert self.env.num_agents == 1
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
states, infos = self.env.reset()
@@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer):
disable=self.disable_progressbar,
file=sys.stdout,
):
# Pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
with torch.no_grad():
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
if not self.headless:
self.env.render()
self.agents.record_transition(
states=states,
actions=actions,
rewards=rewards,
next_states=next_states,
terminated=terminated,
truncated=truncated,
infos=infos,
timestep=timestep,
timesteps=self.timesteps,
states=states, actions=actions, rewards=rewards,
next_states=next_states, terminated=terminated,
truncated=truncated, infos=infos,
timestep=timestep, timesteps=self.timesteps,
)
if self.environment_info in infos:
@@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer):
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
# Reset environments
# Auto-reset for multi-env; single-env resets manually
if self.env.num_envs > 1:
states = next_states
else:
@@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer):
else:
states = next_states
# Record video at intervals
cfg = self._trainer_config
# Periodic video recording
if (
cfg
and cfg.record_video_every > 0
and (timestep + 1) % cfg.record_video_every == 0
self._tcfg
and self._tcfg.record_video_every > 0
and (timestep + 1) % self._tcfg.record_video_every == 0
):
self._record_video(timestep + 1)
def _get_video_fps(self) -> int:
"""Derive video fps from the simulation rate, or use configured value."""
cfg = self._trainer_config
if cfg.record_video_fps > 0:
return cfg.record_video_fps
# Auto-derive from runner's simulation parameters
runner = self.env
dt = getattr(runner.config, "dt", 0.02)
substeps = getattr(runner.config, "substeps", 1)
# ── helpers ───────────────────────────────────────────────────────
def _get_fps(self) -> int:
if self._tcfg and self._tcfg.record_video_fps > 0:
return self._tcfg.record_video_fps
dt = getattr(self.env.config, "dt", 0.02)
substeps = getattr(self.env.config, "substeps", 1)
return max(1, int(round(1.0 / (dt * substeps))))
def _record_video(self, timestep: int) -> None:
"""Record evaluation episodes and upload to ClearML."""
try:
import imageio.v3 as iio
except ImportError:
try:
import imageio as iio
except ImportError:
return
return
cfg = self._trainer_config
fps = self._get_video_fps()
min_frames = int(cfg.record_video_min_seconds * fps)
max_frames = min_frames * 3 # hard cap to prevent runaway recording
fps = self._get_fps()
max_steps = getattr(self.env.env.config, "max_steps", 500)
frames: list[np.ndarray] = []
while len(frames) < min_frames and len(frames) < max_frames:
obs, _ = self.env.reset()
done = False
steps = 0
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
while not done and steps < max_episode_steps:
with torch.no_grad():
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action)
frame = self.env.render(mode="rgb_array")
if frame is not None:
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
done = (terminated | truncated).any().item()
steps += 1
if len(frames) >= max_frames:
break
obs, _ = self.env.reset()
for _ in range(max_steps):
with torch.no_grad():
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action)
frame = self.env.render()
if frame is not None:
frames.append(frame)
if (terminated | truncated).any().item():
break
if frames:
video_path = str(self._video_dir / f"step_{timestep}.mp4")
iio.imwrite(video_path, frames, fps=fps)
path = str(self._video_dir / f"step_{timestep}.mp4")
iio.imwrite(path, frames, fps=fps)
logger = Logger.current_logger()
if logger:
logger.report_media(
title="Training Video",
series=f"step_{timestep}",
local_path=video_path,
iteration=timestep,
"Training Video", f"step_{timestep}",
local_path=path, iteration=timestep,
)
# Reset back to training state after recording
# Restore training state
self.env.reset()
# ── Main trainer ─────────────────────────────────────────────────────
class Trainer:
def __init__(self, runner: BaseRunner, config: TrainerConfig):
self.runner = runner
self.config = config
self._init_clearml()
self._init_agent()
def _init_clearml(self) -> None:
if self.config.clearml_project and self.config.clearml_task:
self.clearml_task = Task.init(
project_name=self.config.clearml_project,
task_name=self.config.clearml_task,
)
else:
self.clearml_task = None
def _init_agent(self) -> None:
device: torch.device = self.runner.device
obs_space: spaces.Space = self.runner.observation_space
act_space: spaces.Space = self.runner.action_space
num_envs: int = self.runner.num_envs
device = self.runner.device
obs_space = self.runner.observation_space
act_space = self.runner.action_space
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
self.memory = RandomMemory(
memory_size=self.config.rollout_steps,
num_envs=self.runner.num_envs,
device=device,
)
self.model: SharedMLP = SharedMLP(
self.model = SharedMLP(
observation_space=obs_space,
action_space=act_space,
device=device,
hidden_sizes=self.config.hidden_sizes,
initial_log_std=0.5,
min_log_std=-2.0,
)
models = {
"policy": self.model,
"value": self.model,
}
models = {"policy": self.model, "value": self.model}
agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg.update({
@@ -217,9 +187,19 @@ class Trainer:
"ratio_clip": self.config.clip_ratio,
"value_loss_scale": self.config.value_loss_scale,
"entropy_loss_scale": self.config.entropy_loss_scale,
"state_preprocessor": RunningStandardScaler,
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
"value_preprocessor": RunningStandardScaler,
"value_preprocessor_kwargs": {"size": 1, "device": device},
})
# Wire up logging frequency: write_interval is in timesteps.
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
agent_cfg["experiment"]["checkpoint_interval"] = max(
self.config.total_timesteps // 10, self.config.rollout_steps
)
self.agent: PPO = PPO(
self.agent = PPO(
models=models,
memory=self.memory,
observation_space=obs_space,
@@ -238,6 +218,4 @@ class Trainer:
trainer.train()
def close(self) -> None:
self.runner.close()
if self.clearml_task:
self.clearml_task.close()
self.runner.close()

View File

@@ -1,39 +1,80 @@
import hydra
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
from src.core.env import ActuatorConfig
# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
"""Instantiate the right env + config from the Hydra config-group name."""
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
# Convert actuator dicts → ActuatorConfig objects
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return CartPoleConfig(**env_dict)
return env_cls(config_cls(**env_dict))
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Create the ClearML task for this run.

    The task name (``<env>-<runner>-<training>``) and tags are derived from
    the Hydra config-group *choices*; the project is "RL-Framework". When
    *remote* is set, execution is handed off to the "default" worker queue.
    """
    Task.ignore_requirements("torch")
    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    algo_name = choices.get("training", "ppo")
    parts = [env_name, runner_name, algo_name]
    task = Task.init(
        project_name="RL-Framework",
        task_name="-".join(parts),
        tags=parts,
    )
    if remote:
        task.execute_remotely(queue_name="default")
    return task
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
env_config = _build_env_config(cfg)
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
choices = HydraConfig.get().runtime.choices
# ClearML init — must happen before heavy work so remote execution
# can take over early. The remote worker re-runs the full script;
# execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
# Build ClearML task name dynamically from Hydra config group choices
if not training_dict.get("clearml_task"):
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "env")
runner_name = choices.get("runner", "runner")
training_name = choices.get("training", "algo")
training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
remote = training_dict.pop("remote", False)
task = _init_clearml(choices, remote=remote)
env_name = choices.get("env", "cartpole")
env = _build_env(env_name, cfg)
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
trainer_config = TrainerConfig(**training_dict)
env = CartPoleEnv(env_config)
runner = MuJoCoRunner(env=env, config=runner_config)
trainer = Trainer(runner=runner, config=trainer_config)
@@ -41,6 +82,7 @@ def main(cfg: DictConfig) -> None:
trainer.train()
finally:
trainer.close()
task.close()
if __name__ == "__main__":