✨ add rotary cartpole env
This commit is contained in:
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
Binary file not shown.
105
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
105
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
@@ -0,0 +1,105 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<robot name="rotary_cartpole">
|
||||
|
||||
<!-- Fixed world frame -->
|
||||
<link name="world"/>
|
||||
|
||||
<!-- Base: motor housing, fixed to world -->
|
||||
<link name="base_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0"/>
|
||||
<mass value="0.921"/>
|
||||
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559"
|
||||
ixy="0.0" iyz="-0.000149" ixz="6e-06"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<joint name="base_joint" type="fixed">
|
||||
<parent link="world"/>
|
||||
<child link="base_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- Arm: horizontal rotating arm driven by motor.
|
||||
Real mass ~10g (Fusion assumed dense material, exported 279g). -->
|
||||
<link name="arm">
|
||||
<inertial>
|
||||
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
||||
<mass value="0.150"/>
|
||||
<inertia ixx="4.05e-05" iyy="1.17e-05" izz="3.66e-05"
|
||||
ixy="0.0" iyz="1.08e-07" ixz="0.0"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Motor joint: base → arm, rotates around vertical z-axis -->
|
||||
<joint name="motor_joint" type="revolute">
|
||||
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0"/>
|
||||
<parent link="base_link"/>
|
||||
<child link="arm"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0"/>
|
||||
<dynamics damping="0.001"/>
|
||||
</joint>
|
||||
|
||||
<!-- Pendulum: swings freely at the end of the arm.
|
||||
Real mass: 5g pendulum + 10g weight at the tip (70mm from bearing) = 15g total.
|
||||
(Fusion assumed dense material, exported 57g for the pendulum alone.) -->
|
||||
<link name="pendulum">
|
||||
<inertial>
|
||||
<!-- Combined CoM: 5g rod (CoM ~35mm) + 10g tip weight at 70mm from pivot.
|
||||
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
||||
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
||||
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
||||
<origin xyz="0.0583 -0.0583 0.0" rpy="0 0 0"/>
|
||||
<mass value="0.015"/>
|
||||
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
||||
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
||||
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). -->
|
||||
<joint name="pendulum_joint" type="continuous">
|
||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 0 0"/>
|
||||
<parent link="arm"/>
|
||||
<child link="pendulum"/>
|
||||
<axis xyz="0 -1 0"/>
|
||||
<dynamics damping="0.0005"/>
|
||||
</joint>
|
||||
|
||||
</robot>
|
||||
7
configs/env/rotary_cartpole.yaml
vendored
Normal file
7
configs/env/rotary_cartpole.yaml
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
max_steps: 1000
|
||||
model_path: assets/rotary_cartpole/rotary_cartpole.urdf
|
||||
reward_upright_scale: 1.0
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
gear: 15.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
@@ -1,4 +1,5 @@
|
||||
num_envs: 16
|
||||
num_envs: 64
|
||||
device: cpu
|
||||
dt: 0.02
|
||||
substeps: 2
|
||||
dt: 0.002
|
||||
substeps: 20
|
||||
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
hidden_sizes: [128, 128]
|
||||
total_timesteps: 1000000
|
||||
total_timesteps: 5000000
|
||||
rollout_steps: 1024
|
||||
learning_epochs: 4
|
||||
mini_batches: 4
|
||||
@@ -8,6 +8,8 @@ gae_lambda: 0.95
|
||||
learning_rate: 0.0003
|
||||
clip_ratio: 0.2
|
||||
value_loss_scale: 0.5
|
||||
entropy_loss_scale: 0.01
|
||||
entropy_loss_scale: 0.05
|
||||
log_interval: 10
|
||||
clearml_project: RL-Framework
|
||||
|
||||
# ClearML remote execution (GPU worker)
|
||||
remote: false
|
||||
|
||||
@@ -5,4 +5,6 @@ omegaconf
|
||||
mujoco
|
||||
skrl[torch]
|
||||
clearml
|
||||
imageio
|
||||
imageio-ffmpeg
|
||||
pytest
|
||||
@@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]):
|
||||
|
||||
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
    """Flag the environments whose episode has used up its step budget.

    ``step_counts`` is compared element-wise against
    ``self.config.max_steps``; entries that have reached or passed the
    limit come back truthy.
    """
    step_limit = self.config.max_steps
    return step_counts >= step_limit
|
||||
|
||||
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||
"""Return the default joint positions for reset.
|
||||
Override in subclass if the URDF zero pose doesn't match
|
||||
the desired initial state (e.g. pendulum hanging down).
|
||||
Returns None to use the URDF default (all zeros)."""
|
||||
return None
|
||||
|
||||
@@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
raise NotImplementedError("Render method not implemented for this runner.")
|
||||
def render(self, env_idx: int = 0):
    """Offscreen render → RGB numpy array.

    Placeholder on the base runner: it always raises
    ``NotImplementedError``.  Runners that can draw frames override it.
    """
    message = "Render not implemented for this runner."
    raise NotImplementedError(message)
|
||||
|
||||
def close(self) -> None:
|
||||
self._sim_close()
|
||||
94
src/envs/rotary_cartpole.py
Normal file
94
src/envs/rotary_cartpole.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import dataclasses
|
||||
import torch
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleState:
    """Joint-space snapshot of the rotary cart-pole, batched over environments."""

    # Every field is a tensor of shape (num_envs,).
    motor_angle: torch.Tensor
    motor_vel: torch.Tensor
    pendulum_angle: torch.Tensor
    pendulum_vel: torch.Tensor
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Task configuration for the rotary inverted pendulum (Furuta pendulum).

    A motor turns the arm in the horizontal plane while the pendulum swings
    freely at the arm tip.  The goal is to swing the pendulum up and keep it
    balanced upright.
    """

    # Reward shaping: weight for the upright-reward term.
    # NOTE(review): confirm that the environment's reward reads this value.
    reward_upright_scale: float = 1.0
|
||||
|
||||
|
||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain:
        base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum),
         motor_vel, pendulum_vel]
        Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint; the action space is [-1, 1].
    """

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)

    @property
    def observation_space(self) -> spaces.Space:
        """Unbounded 6-dimensional box matching ``compute_observations``."""
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        """One-dimensional box bounded to [-1, 1]."""
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Split batched joint positions/velocities into a typed state.

        Column 0 of ``qpos``/``qvel`` is read as the motor joint and
        column 1 as the pendulum joint.
        """
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Stack the six observation features along a new last dimension."""
        return torch.stack([
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ], dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Upright reward minus a small motor-effort penalty.

        The upright term is weighted by ``5.0 * config.reward_upright_scale``.
        With the default scale of 1.0 this is exactly the previous
        hard-coded weight of 5.0; the config field used to be ignored.
        """
        # height: sin(θ) → -1 (down) to +1 (up)
        height = torch.sin(state.pendulum_angle)

        # Upright reward: strongly rewards being near vertical.
        # Clamped and squared so that (before weighting):
        #   down  (h=-1): 0.0
        #   horiz (h= 0): 0.0
        #   up    (h=+1): 1.0
        # Only kicks in above horizontal, so swing-up isn't penalised.
        upright_reward = torch.clamp(height, 0.0, 1.0) ** 2

        # Motor effort penalty: small cost to avoid bang-bang control.
        # The trailing dimension of ``actions`` is squeezed away so the
        # penalty lines up with the per-environment reward.
        effort_penalty = 0.001 * actions.squeeze(-1) ** 2

        upright_weight = 5.0 * self.config.reward_upright_scale
        return upright_weight * upright_reward - effort_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return an all-False boolean mask shaped like ``state.motor_angle``.

        No early termination — the episode runs for max_steps (truncation
        only).  The agent must learn to swing up AND balance continuously.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        """Reset pose ``[motor, pendulum]`` with the pendulum hanging down.

        ``nq`` is accepted for interface compatibility and is not used.
        """
        import math

        # The STL mesh is horizontal at qpos=0.
        # Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
        return [0.0, -math.pi / 2]
|
||||
@@ -4,7 +4,7 @@ from gymnasium import spaces
|
||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||
|
||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
Model.__init__(self, observation_space, action_space, device)
|
||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
||||
DeterministicMixin.__init__(self, clip_actions)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import dataclasses
|
||||
import tempfile
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
device: str = "cpu"
|
||||
dt: float = 0.02
|
||||
substeps: int = 2
|
||||
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
||||
@@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
"""
|
||||
# Step 1: Load URDF/MJCF as-is (no actuators)
|
||||
model_raw = mujoco.MjModel.from_xml_path(model_path)
|
||||
abs_path = os.path.abspath(model_path)
|
||||
model_dir = os.path.dirname(abs_path)
|
||||
is_urdf = abs_path.lower().endswith(".urdf")
|
||||
|
||||
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||
# temporary copy. The original URDF stays clean and simulator-agnostic.
|
||||
if is_urdf:
|
||||
tree = ET.parse(abs_path)
|
||||
root = tree.getroot()
|
||||
# Detect the mesh subdirectory from the first mesh filename
|
||||
meshdir = None
|
||||
for mesh_el in root.iter("mesh"):
|
||||
fn = mesh_el.get("filename", "")
|
||||
dirname = os.path.dirname(fn)
|
||||
if dirname:
|
||||
meshdir = dirname
|
||||
break
|
||||
if meshdir:
|
||||
mj_ext = ET.SubElement(root, "mujoco")
|
||||
ET.SubElement(mj_ext, "compiler", attrib={
|
||||
"meshdir": meshdir,
|
||||
"balanceinertia": "true",
|
||||
})
|
||||
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
||||
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
||||
try:
|
||||
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
||||
finally:
|
||||
os.unlink(tmp_urdf)
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
||||
|
||||
if not actuators:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation
|
||||
tmp_mjcf = tempfile.mktemp(suffix=".xml")
|
||||
# Step 2: Export internal MJCF representation (save next to original
|
||||
# model so relative mesh/asset paths resolve correctly on reload)
|
||||
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||
with open(tmp_mjcf) as f:
|
||||
mjcf_str = f.read()
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
# Use torque actuator. Speed is limited by joint damping:
|
||||
# at steady state, vel_max = torque / damping.
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Add damping to actuated joints to limit max speed and
|
||||
# mimic real motor friction / back-EMF.
|
||||
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
||||
actuated_joints = {a.joint for a in actuators}
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("name") in actuated_joints:
|
||||
jnt.set("damping", "0.05")
|
||||
|
||||
# Disable self-collision on all geoms.
|
||||
# URDF mesh convex hulls often overlap at joints (especially
|
||||
# grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
|
||||
# causing phantom contact forces.
|
||||
for geom in root.iter("geom"):
|
||||
geom.set("contype", "0")
|
||||
geom.set("conaffinity", "0")
|
||||
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
with open(tmp_mjcf, "w") as f:
|
||||
f.write(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
||||
finally:
|
||||
import os
|
||||
os.unlink(tmp_mjcf)
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Step 4: Reload from modified MJCF
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
return mujoco.MjModel.from_xml_string(modified_xml)
|
||||
if os.path.exists(tmp_mjcf):
|
||||
os.unlink(tmp_mjcf)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
self._nq = self._model.nq
|
||||
self._nv = self._model.nv
|
||||
|
||||
# Per-env smoothed ctrl state for EMA action filtering.
|
||||
# Models real motor inertia: ctrl can't reverse instantly.
|
||||
nu = self._model.nu
|
||||
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
actions_np: np.ndarray = actions.cpu().numpy()
|
||||
alpha = self.config.action_ema_alpha
|
||||
|
||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||
|
||||
for i, data in enumerate(self._data):
|
||||
data.ctrl[:] = actions_np[i]
|
||||
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
|
||||
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
||||
data.ctrl[:] = self._smooth_ctrl[i]
|
||||
for _ in range(self.config.substeps):
|
||||
mujoco.mj_step(self._model, data)
|
||||
|
||||
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
||||
|
||||
default_qpos = self.env.get_default_qpos(self._nq)
|
||||
|
||||
for i, env_id in enumerate(ids):
|
||||
data = self._data[env_id]
|
||||
mujoco.mj_resetData(self._model, data)
|
||||
|
||||
# Add small random perturbation so the pole doesn't start perfectly upright
|
||||
# Set initial pose (env-specific, e.g. pendulum hanging down)
|
||||
if default_qpos is not None:
|
||||
data.qpos[:] = default_qpos
|
||||
|
||||
# Add small random perturbation for exploration
|
||||
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||
|
||||
# Reset smoothed ctrl so motor starts from rest
|
||||
self._smooth_ctrl[env_id][:] = 0.0
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
qvel_batch[i] = data.qvel
|
||||
|
||||
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
)
|
||||
|
||||
def _sim_close(self) -> None:
|
||||
if hasattr(self, "_viewer") and self._viewer is not None:
|
||||
self._viewer.close()
|
||||
self._viewer = None
|
||||
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
self._offscreen_renderer = None
|
||||
|
||||
self._data.clear()
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
if mode == "human":
|
||||
if not hasattr(self, "_viewer") or self._viewer is None:
|
||||
self._viewer = mujoco.viewer.launch_passive(
|
||||
self._model, self._data[env_idx]
|
||||
)
|
||||
# Update visual geometry from current physics state
|
||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||
self._viewer.sync()
|
||||
return None
|
||||
elif mode == "rgb_array":
|
||||
# Cache the offscreen renderer to avoid create/destroy overhead
|
||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
|
||||
return torch.from_numpy(pixels)
|
||||
def render(self, env_idx: int = 0) -> np.ndarray | None:
    """Offscreen render → RGB numpy array (H, W, 3).

    The ``mujoco.Renderer`` (height 480, width 640) is built on the first
    call and kept on the instance for later calls.  The scene is updated
    from ``self._data[env_idx]`` and a copy of the rendered frame is
    returned.
    """
    renderer = getattr(self, "_offscreen_renderer", None)
    if renderer is None:
        renderer = mujoco.Renderer(self._model, height=480, width=640)
        self._offscreen_renderer = renderer
    renderer.update_scene(self._data[env_idx])
    return renderer.render().copy()
|
||||
@@ -4,19 +4,22 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from clearml import Logger
|
||||
from gymnasium import spaces
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
from src.core.runner import BaseRunner
|
||||
from clearml import Task, Logger
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from src.models.mlp import SharedMLP
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainerConfig:
|
||||
# PPO
|
||||
rollout_steps: int = 2048
|
||||
learning_epochs: int = 8
|
||||
mini_batches: int = 4
|
||||
@@ -29,30 +32,27 @@ class TrainerConfig:
|
||||
|
||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
|
||||
# Video recording
|
||||
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
|
||||
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
|
||||
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
|
||||
# Video recording (uploaded to ClearML)
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||
|
||||
clearml_project: str | None = None
|
||||
clearml_task: str | None = None
|
||||
|
||||
# ── Video-recording trainer ──────────────────────────────────────────
|
||||
|
||||
class VideoRecordingTrainer(SequentialTrainer):
|
||||
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
|
||||
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
|
||||
|
||||
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
|
||||
super().__init__(env=env, agents=agents, cfg=cfg)
|
||||
self._trainer_config = trainer_config
|
||||
self._tcfg = trainer_config
|
||||
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
|
||||
|
||||
def single_agent_train(self) -> None:
|
||||
"""Override to add periodic video recording."""
|
||||
assert self.num_simultaneous_agents == 1
|
||||
assert self.env.num_agents == 1
|
||||
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
|
||||
|
||||
states, infos = self.env.reset()
|
||||
|
||||
@@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
disable=self.disable_progressbar,
|
||||
file=sys.stdout,
|
||||
):
|
||||
# Pre-interaction
|
||||
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
with torch.no_grad():
|
||||
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
|
||||
|
||||
if not self.headless:
|
||||
self.env.render()
|
||||
|
||||
self.agents.record_transition(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_states=next_states,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
infos=infos,
|
||||
timestep=timestep,
|
||||
timesteps=self.timesteps,
|
||||
states=states, actions=actions, rewards=rewards,
|
||||
next_states=next_states, terminated=terminated,
|
||||
truncated=truncated, infos=infos,
|
||||
timestep=timestep, timesteps=self.timesteps,
|
||||
)
|
||||
|
||||
if self.environment_info in infos:
|
||||
@@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
|
||||
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
# Reset environments
|
||||
# Auto-reset for multi-env; single-env resets manually
|
||||
if self.env.num_envs > 1:
|
||||
states = next_states
|
||||
else:
|
||||
@@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
else:
|
||||
states = next_states
|
||||
|
||||
# Record video at intervals
|
||||
cfg = self._trainer_config
|
||||
# Periodic video recording
|
||||
if (
|
||||
cfg
|
||||
and cfg.record_video_every > 0
|
||||
and (timestep + 1) % cfg.record_video_every == 0
|
||||
self._tcfg
|
||||
and self._tcfg.record_video_every > 0
|
||||
and (timestep + 1) % self._tcfg.record_video_every == 0
|
||||
):
|
||||
self._record_video(timestep + 1)
|
||||
|
||||
def _get_video_fps(self) -> int:
|
||||
"""Derive video fps from the simulation rate, or use configured value."""
|
||||
cfg = self._trainer_config
|
||||
if cfg.record_video_fps > 0:
|
||||
return cfg.record_video_fps
|
||||
# Auto-derive from runner's simulation parameters
|
||||
runner = self.env
|
||||
dt = getattr(runner.config, "dt", 0.02)
|
||||
substeps = getattr(runner.config, "substeps", 1)
|
||||
# ── helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _get_fps(self) -> int:
|
||||
if self._tcfg and self._tcfg.record_video_fps > 0:
|
||||
return self._tcfg.record_video_fps
|
||||
dt = getattr(self.env.config, "dt", 0.02)
|
||||
substeps = getattr(self.env.config, "substeps", 1)
|
||||
return max(1, int(round(1.0 / (dt * substeps))))
|
||||
|
||||
def _record_video(self, timestep: int) -> None:
|
||||
"""Record evaluation episodes and upload to ClearML."""
|
||||
try:
|
||||
import imageio.v3 as iio
|
||||
except ImportError:
|
||||
try:
|
||||
import imageio as iio
|
||||
except ImportError:
|
||||
return
|
||||
return
|
||||
|
||||
cfg = self._trainer_config
|
||||
fps = self._get_video_fps()
|
||||
min_frames = int(cfg.record_video_min_seconds * fps)
|
||||
max_frames = min_frames * 3 # hard cap to prevent runaway recording
|
||||
fps = self._get_fps()
|
||||
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
frames: list[np.ndarray] = []
|
||||
|
||||
while len(frames) < min_frames and len(frames) < max_frames:
|
||||
obs, _ = self.env.reset()
|
||||
done = False
|
||||
steps = 0
|
||||
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
while not done and steps < max_episode_steps:
|
||||
with torch.no_grad():
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
frame = self.env.render(mode="rgb_array")
|
||||
if frame is not None:
|
||||
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
|
||||
done = (terminated | truncated).any().item()
|
||||
steps += 1
|
||||
if len(frames) >= max_frames:
|
||||
break
|
||||
obs, _ = self.env.reset()
|
||||
for _ in range(max_steps):
|
||||
with torch.no_grad():
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
|
||||
frame = self.env.render()
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
if (terminated | truncated).any().item():
|
||||
break
|
||||
|
||||
if frames:
|
||||
video_path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(video_path, frames, fps=fps)
|
||||
path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(path, frames, fps=fps)
|
||||
|
||||
logger = Logger.current_logger()
|
||||
if logger:
|
||||
logger.report_media(
|
||||
title="Training Video",
|
||||
series=f"step_{timestep}",
|
||||
local_path=video_path,
|
||||
iteration=timestep,
|
||||
"Training Video", f"step_{timestep}",
|
||||
local_path=path, iteration=timestep,
|
||||
)
|
||||
|
||||
# Reset back to training state after recording
|
||||
# Restore training state
|
||||
self.env.reset()
|
||||
|
||||
|
||||
# ── Main trainer ─────────────────────────────────────────────────────
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, runner: BaseRunner, config: TrainerConfig):
|
||||
self.runner = runner
|
||||
self.config = config
|
||||
|
||||
self._init_clearml()
|
||||
self._init_agent()
|
||||
|
||||
def _init_clearml(self) -> None:
|
||||
if self.config.clearml_project and self.config.clearml_task:
|
||||
self.clearml_task = Task.init(
|
||||
project_name=self.config.clearml_project,
|
||||
task_name=self.config.clearml_task,
|
||||
)
|
||||
else:
|
||||
self.clearml_task = None
|
||||
|
||||
def _init_agent(self) -> None:
|
||||
device: torch.device = self.runner.device
|
||||
obs_space: spaces.Space = self.runner.observation_space
|
||||
act_space: spaces.Space = self.runner.action_space
|
||||
num_envs: int = self.runner.num_envs
|
||||
device = self.runner.device
|
||||
obs_space = self.runner.observation_space
|
||||
act_space = self.runner.action_space
|
||||
|
||||
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
|
||||
self.memory = RandomMemory(
|
||||
memory_size=self.config.rollout_steps,
|
||||
num_envs=self.runner.num_envs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.model: SharedMLP = SharedMLP(
|
||||
self.model = SharedMLP(
|
||||
observation_space=obs_space,
|
||||
action_space=act_space,
|
||||
device=device,
|
||||
hidden_sizes=self.config.hidden_sizes,
|
||||
initial_log_std=0.5,
|
||||
min_log_std=-2.0,
|
||||
)
|
||||
|
||||
models = {
|
||||
"policy": self.model,
|
||||
"value": self.model,
|
||||
}
|
||||
models = {"policy": self.model, "value": self.model}
|
||||
|
||||
agent_cfg = PPO_DEFAULT_CONFIG.copy()
|
||||
agent_cfg.update({
|
||||
@@ -217,9 +187,19 @@ class Trainer:
|
||||
"ratio_clip": self.config.clip_ratio,
|
||||
"value_loss_scale": self.config.value_loss_scale,
|
||||
"entropy_loss_scale": self.config.entropy_loss_scale,
|
||||
"state_preprocessor": RunningStandardScaler,
|
||||
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
||||
"value_preprocessor": RunningStandardScaler,
|
||||
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
||||
})
|
||||
# Wire up logging frequency: skrl's write_interval is measured in timesteps,
|
||||
# so log_interval=N writes metrics every N environment timesteps.
|
||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||
self.config.total_timesteps // 10, self.config.rollout_steps
|
||||
)
|
||||
|
||||
self.agent: PPO = PPO(
|
||||
self.agent = PPO(
|
||||
models=models,
|
||||
memory=self.memory,
|
||||
observation_space=obs_space,
|
||||
@@ -238,6 +218,4 @@ class Trainer:
|
||||
trainer.train()
|
||||
|
||||
def close(self) -> None:
|
||||
self.runner.close()
|
||||
if self.clearml_task:
|
||||
self.clearml_task.close()
|
||||
self.runner.close()
|
||||
68
train.py
68
train.py
@@ -1,39 +1,80 @@
|
||||
import hydra
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
from src.core.env import ActuatorConfig
|
||||
|
||||
# ── env registry ──────────────────────────────────────────────────────
# Keyed by the Hydra ``env`` config-group name. Each entry pairs the
# environment class with the config class used to construct it.
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = dict(
    cartpole=(CartPoleEnv, CartPoleConfig),
    rotary_cartpole=(RotaryCartPoleEnv, RotaryCartPoleConfig),
)
|
||||
|
||||
|
||||
def _build_env_config(
    cfg: DictConfig,
    config_cls: type[BaseEnvConfig] = CartPoleConfig,
) -> BaseEnvConfig:
    """Build an env config object from the ``cfg.env`` section.

    Args:
        cfg: Composed Hydra config; only ``cfg.env`` is read.
        config_cls: Config class instantiated with the resolved ``cfg.env``
            mapping as keyword arguments. Defaults to ``CartPoleConfig`` so
            single-argument callers keep their previous behaviour.

    Returns:
        An instance of ``config_cls``.
    """
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)

    # Convert actuator dicts → ActuatorConfig objects
    if "actuators" in env_dict:
        for a in env_dict["actuators"]:
            if "ctrl_range" in a:
                # The container holds a list here; ActuatorConfig is handed
                # a tuple instead.
                a["ctrl_range"] = tuple(a["ctrl_range"])
        env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]

    return config_cls(**env_dict)


def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Instantiate the right env + config from the Hydra config-group name.

    Args:
        env_name: Key into ``ENV_REGISTRY`` (the selected ``env`` group).
        cfg: Composed Hydra config; ``cfg.env`` supplies the config fields.

    Returns:
        The environment built from the registered (EnvClass, ConfigClass) pair.

    Raises:
        ValueError: If ``env_name`` is not present in ``ENV_REGISTRY``.
    """
    if env_name not in ENV_REGISTRY:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")

    env_cls, config_cls = ENV_REGISTRY[env_name]
    return env_cls(_build_env_config(cfg, config_cls))
|
||||
|
||||
|
||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Initialize the ClearML task for this run.

    The task is created in the fixed ``RL-Framework`` project, named
    ``<env>-<runner>-<training>`` after the Hydra config-group choices, and
    tagged with those same three names.

    Args:
        choices: Hydra runtime choices (config-group name → selected option).
            Missing ``env`` / ``runner`` / ``training`` keys fall back to
            ``"cartpole"`` / ``"mujoco"`` / ``"ppo"``.
        remote: When true, ``task.execute_remotely(queue_name="default")`` is
            called right after the task is initialized.

    Returns:
        The initialized ClearML ``Task``.
    """
    # Exclude torch from the requirements ClearML records for the task.
    Task.ignore_requirements("torch")

    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")

    project = "RL-Framework"
    task_name = f"{env_name}-{runner_name}-{training_name}"
    tags = [env_name, runner_name, training_name]

    task = Task.init(project_name=project, task_name=task_name, tags=tags)

    if remote:
        task.execute_remotely(queue_name="default")

    return task
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: set up ClearML, build env/runner/trainer, and train.

    The environment is selected through ``ENV_REGISTRY`` using the Hydra
    ``env`` config-group choice (falling back to ``"cartpole"``). The
    trainer and the ClearML task are both closed when training finishes or
    raises.

    Args:
        cfg: Composed Hydra config with ``env``, ``runner`` and ``training``
            sections.
    """
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # execute_remotely() is a no-op on the worker side.
    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    # "remote" is consumed here and removed so it is not forwarded to
    # TrainerConfig.
    remote = training_dict.pop("remote", False)
    task = _init_clearml(choices, remote=remote)

    env_name = choices.get("env", "cartpole")
    env = _build_env(env_name, cfg)
    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
    trainer_config = TrainerConfig(**training_dict)

    runner = MuJoCoRunner(env=env, config=runner_config)
    trainer = Trainer(runner=runner, config=trainer_config)

    try:
        trainer.train()
    finally:
        trainer.close()
        task.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user