✨ add rotary cartpole env
This commit is contained in:
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/arm_1.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
BIN
assets/rotary_cartpole/meshes/base_link.stl
Normal file
Binary file not shown.
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
BIN
assets/rotary_cartpole/meshes/pendulum_1.stl
Normal file
Binary file not shown.
105
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
105
assets/rotary_cartpole/rotary_cartpole.urdf
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<robot name="rotary_cartpole">
|
||||||
|
|
||||||
|
<!-- Fixed world frame -->
|
||||||
|
<link name="world"/>
|
||||||
|
|
||||||
|
<!-- Base: motor housing, fixed to world -->
|
||||||
|
<link name="base_link">
|
||||||
|
<inertial>
|
||||||
|
<origin xyz="-0.00011 0.00117 0.06055" rpy="0 0 0"/>
|
||||||
|
<mass value="0.921"/>
|
||||||
|
<inertia ixx="0.002385" iyy="0.002484" izz="0.000559"
|
||||||
|
ixy="0.0" iyz="-0.000149" ixz="6e-06"/>
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/base_link.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
|
||||||
|
<joint name="base_joint" type="fixed">
|
||||||
|
<parent link="world"/>
|
||||||
|
<child link="base_link"/>
|
||||||
|
</joint>
|
||||||
|
|
||||||
|
<!-- Arm: horizontal rotating arm driven by motor.
|
||||||
|
Real mass ~10g (Fusion assumed dense material, exported 279g). -->
|
||||||
|
<link name="arm">
|
||||||
|
<inertial>
|
||||||
|
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
||||||
|
<mass value="0.150"/>
|
||||||
|
<inertia ixx="4.05e-05" iyy="1.17e-05" izz="3.66e-05"
|
||||||
|
ixy="0.0" iyz="1.08e-07" ixz="0.0"/>
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/arm_1.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
|
||||||
|
<!-- Motor joint: base → arm, rotates around vertical z-axis -->
|
||||||
|
<joint name="motor_joint" type="revolute">
|
||||||
|
<origin xyz="-0.006947 0.00395 0.14796" rpy="0 0 0"/>
|
||||||
|
<parent link="base_link"/>
|
||||||
|
<child link="arm"/>
|
||||||
|
<axis xyz="0 0 1"/>
|
||||||
|
<limit lower="-1.5708" upper="1.5708" effort="10.0" velocity="200.0"/>
|
||||||
|
<dynamics damping="0.001"/>
|
||||||
|
</joint>
|
||||||
|
|
||||||
|
<!-- Pendulum: swings freely at the end of the arm.
|
||||||
|
Real mass: 5g pendulum + 10g weight at the tip (70mm from bearing) = 15g total.
|
||||||
|
(Fusion assumed dense material, exported 57g for the pendulum alone.) -->
|
||||||
|
<link name="pendulum">
|
||||||
|
<inertial>
|
||||||
|
<!-- Combined CoM: 5g rod (CoM ~35mm) + 10g tip weight at 70mm from pivot.
|
||||||
|
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
||||||
|
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
||||||
|
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
||||||
|
<origin xyz="0.0583 -0.0583 0.0" rpy="0 0 0"/>
|
||||||
|
<mass value="0.015"/>
|
||||||
|
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
||||||
|
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
||||||
|
</inertial>
|
||||||
|
<visual>
|
||||||
|
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</visual>
|
||||||
|
<collision>
|
||||||
|
<origin xyz="0.006895 -0.023224 -0.162953" rpy="0 0 0"/>
|
||||||
|
<geometry>
|
||||||
|
<mesh filename="meshes/pendulum_1.stl" scale="0.001 0.001 0.001"/>
|
||||||
|
</geometry>
|
||||||
|
</collision>
|
||||||
|
</link>
|
||||||
|
|
||||||
|
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
||||||
|
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). -->
|
||||||
|
<joint name="pendulum_joint" type="continuous">
|
||||||
|
<origin xyz="0.000052 0.019274 0.014993" rpy="0 0 0"/>
|
||||||
|
<parent link="arm"/>
|
||||||
|
<child link="pendulum"/>
|
||||||
|
<axis xyz="0 -1 0"/>
|
||||||
|
<dynamics damping="0.0005"/>
|
||||||
|
</joint>
|
||||||
|
|
||||||
|
</robot>
|
||||||
7
configs/env/rotary_cartpole.yaml
vendored
Normal file
7
configs/env/rotary_cartpole.yaml
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
max_steps: 1000
|
||||||
|
model_path: assets/rotary_cartpole/rotary_cartpole.urdf
|
||||||
|
reward_upright_scale: 1.0
|
||||||
|
actuators:
|
||||||
|
- joint: motor_joint
|
||||||
|
gear: 15.0
|
||||||
|
ctrl_range: [-1.0, 1.0]
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
num_envs: 16
|
num_envs: 64
|
||||||
device: cpu
|
device: cpu
|
||||||
dt: 0.02
|
dt: 0.002
|
||||||
substeps: 2
|
substeps: 20
|
||||||
|
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
hidden_sizes: [128, 128]
|
hidden_sizes: [128, 128]
|
||||||
total_timesteps: 1000000
|
total_timesteps: 5000000
|
||||||
rollout_steps: 1024
|
rollout_steps: 1024
|
||||||
learning_epochs: 4
|
learning_epochs: 4
|
||||||
mini_batches: 4
|
mini_batches: 4
|
||||||
@@ -8,6 +8,8 @@ gae_lambda: 0.95
|
|||||||
learning_rate: 0.0003
|
learning_rate: 0.0003
|
||||||
clip_ratio: 0.2
|
clip_ratio: 0.2
|
||||||
value_loss_scale: 0.5
|
value_loss_scale: 0.5
|
||||||
entropy_loss_scale: 0.01
|
entropy_loss_scale: 0.05
|
||||||
log_interval: 10
|
log_interval: 10
|
||||||
clearml_project: RL-Framework
|
|
||||||
|
# ClearML remote execution (GPU worker)
|
||||||
|
remote: false
|
||||||
|
|||||||
@@ -5,4 +5,6 @@ omegaconf
|
|||||||
mujoco
|
mujoco
|
||||||
skrl[torch]
|
skrl[torch]
|
||||||
clearml
|
clearml
|
||||||
|
imageio
|
||||||
|
imageio-ffmpeg
|
||||||
pytest
|
pytest
|
||||||
@@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]):
|
|||||||
|
|
||||||
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
|
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
|
||||||
return step_counts >= self.config.max_steps
|
return step_counts >= self.config.max_steps
|
||||||
|
|
||||||
|
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||||
|
"""Return the default joint positions for reset.
|
||||||
|
Override in subclass if the URDF zero pose doesn't match
|
||||||
|
the desired initial state (e.g. pendulum hanging down).
|
||||||
|
Returns None to use the URDF default (all zeros)."""
|
||||||
|
return None
|
||||||
|
|||||||
@@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]):
|
|||||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||||
|
|
||||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
def render(self, env_idx: int = 0):
|
||||||
raise NotImplementedError("Render method not implemented for this runner.")
|
"""Offscreen render → RGB numpy array. Override in subclass."""
|
||||||
|
raise NotImplementedError("Render not implemented for this runner.")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self._sim_close()
|
self._sim_close()
|
||||||
94
src/envs/rotary_cartpole.py
Normal file
94
src/envs/rotary_cartpole.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import dataclasses
|
||||||
|
import torch
|
||||||
|
from src.core.env import BaseEnv, BaseEnvConfig
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class RotaryCartPoleState:
|
||||||
|
motor_angle: torch.Tensor # (num_envs,)
|
||||||
|
motor_vel: torch.Tensor # (num_envs,)
|
||||||
|
pendulum_angle: torch.Tensor # (num_envs,)
|
||||||
|
pendulum_vel: torch.Tensor # (num_envs,)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class RotaryCartPoleConfig(BaseEnvConfig):
|
||||||
|
"""Rotary inverted pendulum (Furuta pendulum) task config.
|
||||||
|
|
||||||
|
The motor rotates the arm horizontally; the pendulum swings freely
|
||||||
|
at the arm tip. Goal: swing the pendulum up and balance it upright.
|
||||||
|
"""
|
||||||
|
# Reward shaping
|
||||||
|
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||||
|
"""Furuta pendulum / rotary inverted pendulum environment.
|
||||||
|
|
||||||
|
Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum
|
||||||
|
|
||||||
|
Observations (6):
|
||||||
|
[sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
|
||||||
|
Using sin/cos avoids discontinuities at ±π for continuous joints.
|
||||||
|
|
||||||
|
Actions (1):
|
||||||
|
Torque applied to the motor_joint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: RotaryCartPoleConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_space(self) -> spaces.Space:
|
||||||
|
return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_space(self) -> spaces.Space:
|
||||||
|
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
|
||||||
|
|
||||||
|
def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
|
||||||
|
return RotaryCartPoleState(
|
||||||
|
motor_angle=qpos[:, 0],
|
||||||
|
motor_vel=qvel[:, 0],
|
||||||
|
pendulum_angle=qpos[:, 1],
|
||||||
|
pendulum_vel=qvel[:, 1],
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||||
|
return torch.stack([
|
||||||
|
torch.sin(state.motor_angle),
|
||||||
|
torch.cos(state.motor_angle),
|
||||||
|
torch.sin(state.pendulum_angle),
|
||||||
|
torch.cos(state.pendulum_angle),
|
||||||
|
state.motor_vel,
|
||||||
|
state.pendulum_vel,
|
||||||
|
], dim=-1)
|
||||||
|
|
||||||
|
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
||||||
|
# height: sin(θ) → -1 (down) to +1 (up)
|
||||||
|
height = torch.sin(state.pendulum_angle)
|
||||||
|
|
||||||
|
# Upright reward: strongly rewards being near vertical.
|
||||||
|
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
|
||||||
|
# down (h=-1): 0.0
|
||||||
|
# horiz (h= 0): 0.0
|
||||||
|
# up (h=+1): 1.0
|
||||||
|
# Only kicks in above horizontal, so swing-up isn't penalised.
|
||||||
|
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
|
||||||
|
|
||||||
|
# Motor effort penalty: small cost to avoid bang-bang control.
|
||||||
|
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
|
||||||
|
|
||||||
|
return 5.0 * upright_reward - effort_penalty
|
||||||
|
|
||||||
|
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||||
|
# No early termination — episode runs for max_steps (truncation only).
|
||||||
|
# The agent must learn to swing up AND balance continuously.
|
||||||
|
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
|
||||||
|
|
||||||
|
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||||
|
# The STL mesh is horizontal at qpos=0.
|
||||||
|
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
|
||||||
|
import math
|
||||||
|
return [0.0, -math.pi / 2]
|
||||||
@@ -4,7 +4,7 @@ from gymnasium import spaces
|
|||||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||||
|
|
||||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||||
Model.__init__(self, observation_space, action_space, device)
|
Model.__init__(self, observation_space, action_space, device)
|
||||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
||||||
DeterministicMixin.__init__(self, clip_actions)
|
DeterministicMixin.__init__(self, clip_actions)
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import tempfile
|
import os
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from src.core.env import BaseEnv, ActuatorConfig
|
from src.core.env import BaseEnv, ActuatorConfig
|
||||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mujoco
|
import mujoco
|
||||||
import mujoco.viewer
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||||
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
|||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
dt: float = 0.02
|
dt: float = 0.02
|
||||||
substeps: int = 2
|
substeps: int = 2
|
||||||
|
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
||||||
|
|
||||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||||
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
||||||
@@ -39,23 +39,53 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
This keeps the URDF clean and standard — actuator config lives in
|
This keeps the URDF clean and standard — actuator config lives in
|
||||||
the env config (Isaac Lab pattern), not in the robot file.
|
the env config (Isaac Lab pattern), not in the robot file.
|
||||||
"""
|
"""
|
||||||
# Step 1: Load URDF/MJCF as-is (no actuators)
|
abs_path = os.path.abspath(model_path)
|
||||||
model_raw = mujoco.MjModel.from_xml_path(model_path)
|
model_dir = os.path.dirname(abs_path)
|
||||||
|
is_urdf = abs_path.lower().endswith(".urdf")
|
||||||
|
|
||||||
|
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||||
|
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||||
|
# temporary copy. The original URDF stays clean and simulator-agnostic.
|
||||||
|
if is_urdf:
|
||||||
|
tree = ET.parse(abs_path)
|
||||||
|
root = tree.getroot()
|
||||||
|
# Detect the mesh subdirectory from the first mesh filename
|
||||||
|
meshdir = None
|
||||||
|
for mesh_el in root.iter("mesh"):
|
||||||
|
fn = mesh_el.get("filename", "")
|
||||||
|
dirname = os.path.dirname(fn)
|
||||||
|
if dirname:
|
||||||
|
meshdir = dirname
|
||||||
|
break
|
||||||
|
if meshdir:
|
||||||
|
mj_ext = ET.SubElement(root, "mujoco")
|
||||||
|
ET.SubElement(mj_ext, "compiler", attrib={
|
||||||
|
"meshdir": meshdir,
|
||||||
|
"balanceinertia": "true",
|
||||||
|
})
|
||||||
|
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
||||||
|
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
||||||
|
try:
|
||||||
|
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp_urdf)
|
||||||
|
else:
|
||||||
|
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
||||||
|
|
||||||
if not actuators:
|
if not actuators:
|
||||||
return model_raw
|
return model_raw
|
||||||
|
|
||||||
# Step 2: Export internal MJCF representation
|
# Step 2: Export internal MJCF representation (save next to original
|
||||||
tmp_mjcf = tempfile.mktemp(suffix=".xml")
|
# model so relative mesh/asset paths resolve correctly on reload)
|
||||||
|
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
||||||
try:
|
try:
|
||||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||||
with open(tmp_mjcf) as f:
|
with open(tmp_mjcf) as f:
|
||||||
mjcf_str = f.read()
|
mjcf_str = f.read()
|
||||||
finally:
|
|
||||||
import os
|
|
||||||
os.unlink(tmp_mjcf)
|
|
||||||
|
|
||||||
# Step 3: Inject actuators into the MJCF XML
|
# Step 3: Inject actuators into the MJCF XML
|
||||||
|
# Use torque actuator. Speed is limited by joint damping:
|
||||||
|
# at steady state, vel_max = torque / damping.
|
||||||
root = ET.fromstring(mjcf_str)
|
root = ET.fromstring(mjcf_str)
|
||||||
act_elem = ET.SubElement(root, "actuator")
|
act_elem = ET.SubElement(root, "actuator")
|
||||||
for act in actuators:
|
for act in actuators:
|
||||||
@@ -66,9 +96,32 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||||
})
|
})
|
||||||
|
|
||||||
# Step 4: Reload from modified MJCF
|
# Add damping to actuated joints to limit max speed and
|
||||||
|
# mimic real motor friction / back-EMF.
|
||||||
|
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
||||||
|
actuated_joints = {a.joint for a in actuators}
|
||||||
|
for body in root.iter("body"):
|
||||||
|
for jnt in body.findall("joint"):
|
||||||
|
if jnt.get("name") in actuated_joints:
|
||||||
|
jnt.set("damping", "0.05")
|
||||||
|
|
||||||
|
# Disable self-collision on all geoms.
|
||||||
|
# URDF mesh convex hulls often overlap at joints (especially
|
||||||
|
# grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
|
||||||
|
# causing phantom contact forces.
|
||||||
|
for geom in root.iter("geom"):
|
||||||
|
geom.set("contype", "0")
|
||||||
|
geom.set("conaffinity", "0")
|
||||||
|
|
||||||
|
# Step 4: Write modified MJCF and reload from file path
|
||||||
|
# (from_xml_path resolves mesh paths relative to the file location)
|
||||||
modified_xml = ET.tostring(root, encoding="unicode")
|
modified_xml = ET.tostring(root, encoding="unicode")
|
||||||
return mujoco.MjModel.from_xml_string(modified_xml)
|
with open(tmp_mjcf, "w") as f:
|
||||||
|
f.write(modified_xml)
|
||||||
|
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_mjcf):
|
||||||
|
os.unlink(tmp_mjcf)
|
||||||
|
|
||||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||||
model_path = self.env.config.model_path
|
model_path = self.env.config.model_path
|
||||||
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
self._nq = self._model.nq
|
self._nq = self._model.nq
|
||||||
self._nv = self._model.nv
|
self._nv = self._model.nv
|
||||||
|
|
||||||
|
# Per-env smoothed ctrl state for EMA action filtering.
|
||||||
|
# Models real motor inertia: ctrl can't reverse instantly.
|
||||||
|
nu = self._model.nu
|
||||||
|
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
||||||
|
|
||||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
actions_np: np.ndarray = actions.cpu().numpy()
|
actions_np: np.ndarray = actions.cpu().numpy()
|
||||||
|
alpha = self.config.action_ema_alpha
|
||||||
|
|
||||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||||
|
|
||||||
for i, data in enumerate(self._data):
|
for i, data in enumerate(self._data):
|
||||||
data.ctrl[:] = actions_np[i]
|
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
|
||||||
|
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
||||||
|
data.ctrl[:] = self._smooth_ctrl[i]
|
||||||
for _ in range(self.config.substeps):
|
for _ in range(self.config.substeps):
|
||||||
mujoco.mj_step(self._model, data)
|
mujoco.mj_step(self._model, data)
|
||||||
|
|
||||||
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
||||||
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
||||||
|
|
||||||
|
default_qpos = self.env.get_default_qpos(self._nq)
|
||||||
|
|
||||||
for i, env_id in enumerate(ids):
|
for i, env_id in enumerate(ids):
|
||||||
data = self._data[env_id]
|
data = self._data[env_id]
|
||||||
mujoco.mj_resetData(self._model, data)
|
mujoco.mj_resetData(self._model, data)
|
||||||
|
|
||||||
# Add small random perturbation so the pole doesn't start perfectly upright
|
# Set initial pose (env-specific, e.g. pendulum hanging down)
|
||||||
|
if default_qpos is not None:
|
||||||
|
data.qpos[:] = default_qpos
|
||||||
|
|
||||||
|
# Add small random perturbation for exploration
|
||||||
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
||||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||||
|
|
||||||
|
# Reset smoothed ctrl so motor starts from rest
|
||||||
|
self._smooth_ctrl[env_id][:] = 0.0
|
||||||
|
|
||||||
qpos_batch[i] = data.qpos
|
qpos_batch[i] = data.qpos
|
||||||
qvel_batch[i] = data.qvel
|
qvel_batch[i] = data.qvel
|
||||||
|
|
||||||
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _sim_close(self) -> None:
|
def _sim_close(self) -> None:
|
||||||
if hasattr(self, "_viewer") and self._viewer is not None:
|
|
||||||
self._viewer.close()
|
|
||||||
self._viewer = None
|
|
||||||
|
|
||||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||||
self._offscreen_renderer.close()
|
self._offscreen_renderer.close()
|
||||||
self._offscreen_renderer = None
|
self._offscreen_renderer = None
|
||||||
|
|
||||||
self._data.clear()
|
self._data.clear()
|
||||||
|
|
||||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
def render(self, env_idx: int = 0) -> np.ndarray | None:
|
||||||
if mode == "human":
|
"""Offscreen render → RGB numpy array (H, W, 3)."""
|
||||||
if not hasattr(self, "_viewer") or self._viewer is None:
|
|
||||||
self._viewer = mujoco.viewer.launch_passive(
|
|
||||||
self._model, self._data[env_idx]
|
|
||||||
)
|
|
||||||
# Update visual geometry from current physics state
|
|
||||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
|
||||||
self._viewer.sync()
|
|
||||||
return None
|
|
||||||
elif mode == "rgb_array":
|
|
||||||
# Cache the offscreen renderer to avoid create/destroy overhead
|
|
||||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
||||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
||||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||||
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
|
return self._offscreen_renderer.render().copy()
|
||||||
return torch.from_numpy(pixels)
|
|
||||||
@@ -4,19 +4,22 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from clearml import Logger
|
||||||
|
from gymnasium import spaces
|
||||||
|
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||||
|
from skrl.memories.torch import RandomMemory
|
||||||
|
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||||
|
from skrl.trainers.torch import SequentialTrainer
|
||||||
|
|
||||||
from src.core.runner import BaseRunner
|
from src.core.runner import BaseRunner
|
||||||
from clearml import Task, Logger
|
|
||||||
import torch
|
|
||||||
from gymnasium import spaces
|
|
||||||
from skrl.memories.torch import RandomMemory
|
|
||||||
from src.models.mlp import SharedMLP
|
from src.models.mlp import SharedMLP
|
||||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
|
||||||
from skrl.trainers.torch import SequentialTrainer
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class TrainerConfig:
|
class TrainerConfig:
|
||||||
|
# PPO
|
||||||
rollout_steps: int = 2048
|
rollout_steps: int = 2048
|
||||||
learning_epochs: int = 8
|
learning_epochs: int = 8
|
||||||
mini_batches: int = 4
|
mini_batches: int = 4
|
||||||
@@ -29,30 +32,27 @@ class TrainerConfig:
|
|||||||
|
|
||||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||||
|
|
||||||
|
# Training
|
||||||
total_timesteps: int = 1_000_000
|
total_timesteps: int = 1_000_000
|
||||||
log_interval: int = 10
|
log_interval: int = 10
|
||||||
|
|
||||||
# Video recording
|
# Video recording (uploaded to ClearML)
|
||||||
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
|
record_video_every: int = 10_000 # 0 = disabled
|
||||||
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
|
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||||
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
|
|
||||||
|
|
||||||
clearml_project: str | None = None
|
|
||||||
clearml_task: str | None = None
|
|
||||||
|
|
||||||
|
# ── Video-recording trainer ──────────────────────────────────────────
|
||||||
|
|
||||||
class VideoRecordingTrainer(SequentialTrainer):
|
class VideoRecordingTrainer(SequentialTrainer):
|
||||||
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
|
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
|
||||||
|
|
||||||
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
|
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
|
||||||
super().__init__(env=env, agents=agents, cfg=cfg)
|
super().__init__(env=env, agents=agents, cfg=cfg)
|
||||||
self._trainer_config = trainer_config
|
self._tcfg = trainer_config
|
||||||
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
|
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
|
||||||
|
|
||||||
def single_agent_train(self) -> None:
|
def single_agent_train(self) -> None:
|
||||||
"""Override to add periodic video recording."""
|
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
|
||||||
assert self.num_simultaneous_agents == 1
|
|
||||||
assert self.env.num_agents == 1
|
|
||||||
|
|
||||||
states, infos = self.env.reset()
|
states, infos = self.env.reset()
|
||||||
|
|
||||||
@@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
disable=self.disable_progressbar,
|
disable=self.disable_progressbar,
|
||||||
file=sys.stdout,
|
file=sys.stdout,
|
||||||
):
|
):
|
||||||
# Pre-interaction
|
|
||||||
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
|
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
|
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
|
||||||
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
|
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
|
||||||
|
|
||||||
if not self.headless:
|
|
||||||
self.env.render()
|
|
||||||
|
|
||||||
self.agents.record_transition(
|
self.agents.record_transition(
|
||||||
states=states,
|
states=states, actions=actions, rewards=rewards,
|
||||||
actions=actions,
|
next_states=next_states, terminated=terminated,
|
||||||
rewards=rewards,
|
truncated=truncated, infos=infos,
|
||||||
next_states=next_states,
|
timestep=timestep, timesteps=self.timesteps,
|
||||||
terminated=terminated,
|
|
||||||
truncated=truncated,
|
|
||||||
infos=infos,
|
|
||||||
timestep=timestep,
|
|
||||||
timesteps=self.timesteps,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.environment_info in infos:
|
if self.environment_info in infos:
|
||||||
@@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
|
|
||||||
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
|
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||||
|
|
||||||
# Reset environments
|
# Auto-reset for multi-env; single-env resets manually
|
||||||
if self.env.num_envs > 1:
|
if self.env.num_envs > 1:
|
||||||
states = next_states
|
states = next_states
|
||||||
else:
|
else:
|
||||||
@@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
else:
|
else:
|
||||||
states = next_states
|
states = next_states
|
||||||
|
|
||||||
# Record video at intervals
|
# Periodic video recording
|
||||||
cfg = self._trainer_config
|
|
||||||
if (
|
if (
|
||||||
cfg
|
self._tcfg
|
||||||
and cfg.record_video_every > 0
|
and self._tcfg.record_video_every > 0
|
||||||
and (timestep + 1) % cfg.record_video_every == 0
|
and (timestep + 1) % self._tcfg.record_video_every == 0
|
||||||
):
|
):
|
||||||
self._record_video(timestep + 1)
|
self._record_video(timestep + 1)
|
||||||
|
|
||||||
def _get_video_fps(self) -> int:
|
# ── helpers ───────────────────────────────────────────────────────
|
||||||
"""Derive video fps from the simulation rate, or use configured value."""
|
|
||||||
cfg = self._trainer_config
|
def _get_fps(self) -> int:
|
||||||
if cfg.record_video_fps > 0:
|
if self._tcfg and self._tcfg.record_video_fps > 0:
|
||||||
return cfg.record_video_fps
|
return self._tcfg.record_video_fps
|
||||||
# Auto-derive from runner's simulation parameters
|
dt = getattr(self.env.config, "dt", 0.02)
|
||||||
runner = self.env
|
substeps = getattr(self.env.config, "substeps", 1)
|
||||||
dt = getattr(runner.config, "dt", 0.02)
|
|
||||||
substeps = getattr(runner.config, "substeps", 1)
|
|
||||||
return max(1, int(round(1.0 / (dt * substeps))))
|
return max(1, int(round(1.0 / (dt * substeps))))
|
||||||
|
|
||||||
def _record_video(self, timestep: int) -> None:
|
def _record_video(self, timestep: int) -> None:
|
||||||
"""Record evaluation episodes and upload to ClearML."""
|
|
||||||
try:
|
try:
|
||||||
import imageio.v3 as iio
|
import imageio.v3 as iio
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
import imageio as iio
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return
|
return
|
||||||
|
|
||||||
cfg = self._trainer_config
|
fps = self._get_fps()
|
||||||
fps = self._get_video_fps()
|
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||||
min_frames = int(cfg.record_video_min_seconds * fps)
|
|
||||||
max_frames = min_frames * 3 # hard cap to prevent runaway recording
|
|
||||||
frames: list[np.ndarray] = []
|
frames: list[np.ndarray] = []
|
||||||
|
|
||||||
while len(frames) < min_frames and len(frames) < max_frames:
|
|
||||||
obs, _ = self.env.reset()
|
obs, _ = self.env.reset()
|
||||||
done = False
|
for _ in range(max_steps):
|
||||||
steps = 0
|
|
||||||
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
|
|
||||||
while not done and steps < max_episode_steps:
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||||
frame = self.env.render(mode="rgb_array")
|
|
||||||
|
frame = self.env.render()
|
||||||
if frame is not None:
|
if frame is not None:
|
||||||
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
|
frames.append(frame)
|
||||||
done = (terminated | truncated).any().item()
|
|
||||||
steps += 1
|
if (terminated | truncated).any().item():
|
||||||
if len(frames) >= max_frames:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if frames:
|
if frames:
|
||||||
video_path = str(self._video_dir / f"step_{timestep}.mp4")
|
path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||||
iio.imwrite(video_path, frames, fps=fps)
|
iio.imwrite(path, frames, fps=fps)
|
||||||
|
|
||||||
logger = Logger.current_logger()
|
logger = Logger.current_logger()
|
||||||
if logger:
|
if logger:
|
||||||
logger.report_media(
|
logger.report_media(
|
||||||
title="Training Video",
|
"Training Video", f"step_{timestep}",
|
||||||
series=f"step_{timestep}",
|
local_path=path, iteration=timestep,
|
||||||
local_path=video_path,
|
|
||||||
iteration=timestep,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset back to training state after recording
|
# Restore training state
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main trainer ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(self, runner: BaseRunner, config: TrainerConfig):
|
def __init__(self, runner: BaseRunner, config: TrainerConfig):
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self._init_clearml()
|
|
||||||
self._init_agent()
|
self._init_agent()
|
||||||
|
|
||||||
def _init_clearml(self) -> None:
|
|
||||||
if self.config.clearml_project and self.config.clearml_task:
|
|
||||||
self.clearml_task = Task.init(
|
|
||||||
project_name=self.config.clearml_project,
|
|
||||||
task_name=self.config.clearml_task,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.clearml_task = None
|
|
||||||
|
|
||||||
def _init_agent(self) -> None:
|
def _init_agent(self) -> None:
|
||||||
device: torch.device = self.runner.device
|
device = self.runner.device
|
||||||
obs_space: spaces.Space = self.runner.observation_space
|
obs_space = self.runner.observation_space
|
||||||
act_space: spaces.Space = self.runner.action_space
|
act_space = self.runner.action_space
|
||||||
num_envs: int = self.runner.num_envs
|
|
||||||
|
|
||||||
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
|
self.memory = RandomMemory(
|
||||||
|
memory_size=self.config.rollout_steps,
|
||||||
|
num_envs=self.runner.num_envs,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
self.model: SharedMLP = SharedMLP(
|
self.model = SharedMLP(
|
||||||
observation_space=obs_space,
|
observation_space=obs_space,
|
||||||
action_space=act_space,
|
action_space=act_space,
|
||||||
device=device,
|
device=device,
|
||||||
hidden_sizes=self.config.hidden_sizes,
|
hidden_sizes=self.config.hidden_sizes,
|
||||||
|
initial_log_std=0.5,
|
||||||
|
min_log_std=-2.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
models = {
|
models = {"policy": self.model, "value": self.model}
|
||||||
"policy": self.model,
|
|
||||||
"value": self.model,
|
|
||||||
}
|
|
||||||
|
|
||||||
agent_cfg = PPO_DEFAULT_CONFIG.copy()
|
agent_cfg = PPO_DEFAULT_CONFIG.copy()
|
||||||
agent_cfg.update({
|
agent_cfg.update({
|
||||||
@@ -217,9 +187,19 @@ class Trainer:
|
|||||||
"ratio_clip": self.config.clip_ratio,
|
"ratio_clip": self.config.clip_ratio,
|
||||||
"value_loss_scale": self.config.value_loss_scale,
|
"value_loss_scale": self.config.value_loss_scale,
|
||||||
"entropy_loss_scale": self.config.entropy_loss_scale,
|
"entropy_loss_scale": self.config.entropy_loss_scale,
|
||||||
|
"state_preprocessor": RunningStandardScaler,
|
||||||
|
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
||||||
|
"value_preprocessor": RunningStandardScaler,
|
||||||
|
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
||||||
})
|
})
|
||||||
|
# Wire up logging frequency: write_interval is in timesteps.
|
||||||
|
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||||
|
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||||
|
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||||
|
self.config.total_timesteps // 10, self.config.rollout_steps
|
||||||
|
)
|
||||||
|
|
||||||
self.agent: PPO = PPO(
|
self.agent = PPO(
|
||||||
models=models,
|
models=models,
|
||||||
memory=self.memory,
|
memory=self.memory,
|
||||||
observation_space=obs_space,
|
observation_space=obs_space,
|
||||||
@@ -239,5 +219,3 @@ class Trainer:
|
|||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.runner.close()
|
self.runner.close()
|
||||||
if self.clearml_task:
|
|
||||||
self.clearml_task.close()
|
|
||||||
70
train.py
70
train.py
@@ -1,39 +1,80 @@
|
|||||||
import hydra
|
import hydra
|
||||||
|
from clearml import Task
|
||||||
from hydra.core.hydra_config import HydraConfig
|
from hydra.core.hydra_config import HydraConfig
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||||
|
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||||
from src.training.trainer import Trainer, TrainerConfig
|
from src.training.trainer import Trainer, TrainerConfig
|
||||||
from src.core.env import ActuatorConfig
|
|
||||||
|
# ── env registry ──────────────────────────────────────────────────────
|
||||||
|
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||||
|
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||||
|
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||||
|
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
|
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||||
|
"""Instantiate the right env + config from the Hydra config-group name."""
|
||||||
|
if env_name not in ENV_REGISTRY:
|
||||||
|
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||||
|
|
||||||
|
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||||
|
|
||||||
|
# Convert actuator dicts → ActuatorConfig objects
|
||||||
if "actuators" in env_dict:
|
if "actuators" in env_dict:
|
||||||
for a in env_dict["actuators"]:
|
for a in env_dict["actuators"]:
|
||||||
if "ctrl_range" in a:
|
if "ctrl_range" in a:
|
||||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||||
return CartPoleConfig(**env_dict)
|
|
||||||
|
return env_cls(config_cls(**env_dict))
|
||||||
|
|
||||||
|
|
||||||
|
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||||
|
"""Initialize ClearML task with project structure and tags.
|
||||||
|
|
||||||
|
Project: RL-Trainings/<EnvName> (e.g. RL-Trainings/Rotary Cartpole)
|
||||||
|
Tags: env, runner, training algo choices from Hydra.
|
||||||
|
"""
|
||||||
|
Task.ignore_requirements("torch")
|
||||||
|
|
||||||
|
env_name = choices.get("env", "cartpole")
|
||||||
|
runner_name = choices.get("runner", "mujoco")
|
||||||
|
training_name = choices.get("training", "ppo")
|
||||||
|
|
||||||
|
project = "RL-Framework"
|
||||||
|
task_name = f"{env_name}-{runner_name}-{training_name}"
|
||||||
|
tags = [env_name, runner_name, training_name]
|
||||||
|
|
||||||
|
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||||
|
|
||||||
|
if remote:
|
||||||
|
task.execute_remotely(queue_name="default")
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
|
@hydra.main(version_base=None, config_path="configs", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
env_config = _build_env_config(cfg)
|
|
||||||
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
|
||||||
|
|
||||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
|
||||||
# Build ClearML task name dynamically from Hydra config group choices
|
|
||||||
if not training_dict.get("clearml_task"):
|
|
||||||
choices = HydraConfig.get().runtime.choices
|
choices = HydraConfig.get().runtime.choices
|
||||||
env_name = choices.get("env", "env")
|
|
||||||
runner_name = choices.get("runner", "runner")
|
# ClearML init — must happen before heavy work so remote execution
|
||||||
training_name = choices.get("training", "algo")
|
# can take over early. The remote worker re-runs the full script;
|
||||||
training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
|
# execute_remotely() is a no-op on the worker side.
|
||||||
|
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||||
|
remote = training_dict.pop("remote", False)
|
||||||
|
task = _init_clearml(choices, remote=remote)
|
||||||
|
|
||||||
|
env_name = choices.get("env", "cartpole")
|
||||||
|
env = _build_env(env_name, cfg)
|
||||||
|
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||||
trainer_config = TrainerConfig(**training_dict)
|
trainer_config = TrainerConfig(**training_dict)
|
||||||
|
|
||||||
env = CartPoleEnv(env_config)
|
|
||||||
runner = MuJoCoRunner(env=env, config=runner_config)
|
runner = MuJoCoRunner(env=env, config=runner_config)
|
||||||
trainer = Trainer(runner=runner, config=trainer_config)
|
trainer = Trainer(runner=runner, config=trainer_config)
|
||||||
|
|
||||||
@@ -41,6 +82,7 @@ def main(cfg: DictConfig) -> None:
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
finally:
|
finally:
|
||||||
trainer.close()
|
trainer.close()
|
||||||
|
task.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user