diff --git a/assets/rotary_cartpole/meshes/arm_1.stl b/assets/rotary_cartpole/meshes/arm_1.stl
new file mode 100644
index 0000000..2cfeba8
Binary files /dev/null and b/assets/rotary_cartpole/meshes/arm_1.stl differ
diff --git a/assets/rotary_cartpole/meshes/base_link.stl b/assets/rotary_cartpole/meshes/base_link.stl
new file mode 100644
index 0000000..d56b548
Binary files /dev/null and b/assets/rotary_cartpole/meshes/base_link.stl differ
diff --git a/assets/rotary_cartpole/meshes/pendulum_1.stl b/assets/rotary_cartpole/meshes/pendulum_1.stl
new file mode 100644
index 0000000..73d39ac
Binary files /dev/null and b/assets/rotary_cartpole/meshes/pendulum_1.stl differ
diff --git a/assets/rotary_cartpole/rotary_cartpole.urdf b/assets/rotary_cartpole/rotary_cartpole.urdf
new file mode 100644
index 0000000..b68ed3b
--- /dev/null
+++ b/assets/rotary_cartpole/rotary_cartpole.urdf
@@ -0,0 +1,105 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml
new file mode 100644
index 0000000..c4fb61c
--- /dev/null
+++ b/configs/env/rotary_cartpole.yaml
@@ -0,0 +1,7 @@
+max_steps: 1000
+model_path: assets/rotary_cartpole/rotary_cartpole.urdf
+reward_upright_scale: 1.0
+actuators:
+ - joint: motor_joint
+ gear: 15.0
+ ctrl_range: [-1.0, 1.0]
diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml
index 42a54b5..861b8ba 100644
--- a/configs/runner/mujoco.yaml
+++ b/configs/runner/mujoco.yaml
@@ -1,4 +1,5 @@
-num_envs: 16
+num_envs: 64
device: cpu
-dt: 0.02
-substeps: 2
+dt: 0.002
+substeps: 20
+action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml
index 8025216..c078de6 100644
--- a/configs/training/ppo.yaml
+++ b/configs/training/ppo.yaml
@@ -1,5 +1,5 @@
hidden_sizes: [128, 128]
-total_timesteps: 1000000
+total_timesteps: 5000000
rollout_steps: 1024
learning_epochs: 4
mini_batches: 4
@@ -8,6 +8,8 @@ gae_lambda: 0.95
learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
-entropy_loss_scale: 0.01
+entropy_loss_scale: 0.05
log_interval: 10
-clearml_project: RL-Framework
+
+# ClearML remote execution (GPU worker)
+remote: false
diff --git a/requirements.txt b/requirements.txt
index f40d2e8..f3ed7df 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,6 @@ omegaconf
mujoco
skrl[torch]
clearml
+imageio
+imageio-ffmpeg
pytest
\ No newline at end of file
diff --git a/src/core/env.py b/src/core/env.py
index 2654898..80c5d96 100644
--- a/src/core/env.py
+++ b/src/core/env.py
@@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]):
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
return step_counts >= self.config.max_steps
+
+ def get_default_qpos(self, nq: int) -> list[float] | None:
+ """Return the default joint positions for reset.
+ Override in subclass if the URDF zero pose doesn't match
+ the desired initial state (e.g. pendulum hanging down).
+ Returns None to use the URDF default (all zeros)."""
+ return None
diff --git a/src/core/runner.py b/src/core/runner.py
index f3e17fa..ff00890 100644
--- a/src/core/runner.py
+++ b/src/core/runner.py
@@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]):
# skrl expects (num_envs, 1) for rewards/terminated/truncated
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
- def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
- raise NotImplementedError("Render method not implemented for this runner.")
+ def render(self, env_idx: int = 0):
+ """Offscreen render → RGB numpy array. Override in subclass."""
+ raise NotImplementedError("Render not implemented for this runner.")
def close(self) -> None:
self._sim_close()
\ No newline at end of file
diff --git a/src/envs/rotary_cartpole.py b/src/envs/rotary_cartpole.py
new file mode 100644
index 0000000..d70dcf2
--- /dev/null
+++ b/src/envs/rotary_cartpole.py
@@ -0,0 +1,94 @@
+import dataclasses
+import torch
+from src.core.env import BaseEnv, BaseEnvConfig
+from gymnasium import spaces
+
+
+@dataclasses.dataclass
+class RotaryCartPoleState:
+ motor_angle: torch.Tensor # (num_envs,)
+ motor_vel: torch.Tensor # (num_envs,)
+ pendulum_angle: torch.Tensor # (num_envs,)
+ pendulum_vel: torch.Tensor # (num_envs,)
+
+
+@dataclasses.dataclass
+class RotaryCartPoleConfig(BaseEnvConfig):
+ """Rotary inverted pendulum (Furuta pendulum) task config.
+
+ The motor rotates the arm horizontally; the pendulum swings freely
+ at the arm tip. Goal: swing the pendulum up and balance it upright.
+ """
+ # Reward shaping
+ reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
+
+
+class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
+ """Furuta pendulum / rotary inverted pendulum environment.
+
+ Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum
+
+ Observations (6):
+ [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
+ Using sin/cos avoids discontinuities at ±π for continuous joints.
+
+ Actions (1):
+ Torque applied to the motor_joint.
+ """
+
+ def __init__(self, config: RotaryCartPoleConfig):
+ super().__init__(config)
+
+ @property
+ def observation_space(self) -> spaces.Space:
+ return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))
+
+ @property
+ def action_space(self) -> spaces.Space:
+ return spaces.Box(low=-1.0, high=1.0, shape=(1,))
+
+ def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
+ return RotaryCartPoleState(
+ motor_angle=qpos[:, 0],
+ motor_vel=qvel[:, 0],
+ pendulum_angle=qpos[:, 1],
+ pendulum_vel=qvel[:, 1],
+ )
+
+ def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
+ return torch.stack([
+ torch.sin(state.motor_angle),
+ torch.cos(state.motor_angle),
+ torch.sin(state.pendulum_angle),
+ torch.cos(state.pendulum_angle),
+ state.motor_vel,
+ state.pendulum_vel,
+ ], dim=-1)
+
+ def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
+ # height: sin(θ) → -1 (down) to +1 (up)
+ height = torch.sin(state.pendulum_angle)
+
+ # Upright reward: strongly rewards being near vertical.
+ # Uses cos(θ - π/2) = sin(θ), squared and scaled so:
+ # down (h=-1): 0.0
+ # horiz (h= 0): 0.0
+ # up (h=+1): 1.0
+ # Only kicks in above horizontal, so swing-up isn't penalised.
+ upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
+
+ # Motor effort penalty: small cost to avoid bang-bang control.
+ effort_penalty = 0.001 * actions.squeeze(-1) ** 2
+
+ return 5.0 * upright_reward - effort_penalty
+
+ def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
+ # No early termination — episode runs for max_steps (truncation only).
+ # The agent must learn to swing up AND balance continuously.
+ return torch.zeros_like(state.motor_angle, dtype=torch.bool)
+
+ def get_default_qpos(self, nq: int) -> list[float] | None:
+ # The STL mesh is horizontal at qpos=0.
+ # Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
+ import math
+ return [0.0, -math.pi / 2]
diff --git a/src/models/mlp.py b/src/models/mlp.py
index ebc4537..398c6a5 100644
--- a/src/models/mlp.py
+++ b/src/models/mlp.py
@@ -4,7 +4,7 @@ from gymnasium import spaces
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
- def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
+ def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
DeterministicMixin.__init__(self, clip_actions)
diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py
index d91e54e..7cf7755 100644
--- a/src/runners/mujoco.py
+++ b/src/runners/mujoco.py
@@ -1,12 +1,11 @@
import dataclasses
-import tempfile
+import os
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
-import mujoco.viewer
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
+ action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
@@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
- # Step 1: Load URDF/MJCF as-is (no actuators)
- model_raw = mujoco.MjModel.from_xml_path(model_path)
+ abs_path = os.path.abspath(model_path)
+ model_dir = os.path.dirname(abs_path)
+ is_urdf = abs_path.lower().endswith(".urdf")
+
+    # MuJoCo's URDF parser strips directory prefixes from mesh filenames,
+    # so we inject a <mujoco><compiler meshdir="..."/> extension block into a
+    # temporary copy. The original URDF stays clean and simulator-agnostic.
+ if is_urdf:
+ tree = ET.parse(abs_path)
+ root = tree.getroot()
+ # Detect the mesh subdirectory from the first mesh filename
+ meshdir = None
+ for mesh_el in root.iter("mesh"):
+ fn = mesh_el.get("filename", "")
+ dirname = os.path.dirname(fn)
+ if dirname:
+ meshdir = dirname
+ break
+ if meshdir:
+ mj_ext = ET.SubElement(root, "mujoco")
+ ET.SubElement(mj_ext, "compiler", attrib={
+ "meshdir": meshdir,
+ "balanceinertia": "true",
+ })
+ tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
+ tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
+ try:
+ model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
+ finally:
+ os.unlink(tmp_urdf)
+ else:
+ model_raw = mujoco.MjModel.from_xml_path(abs_path)
if not actuators:
return model_raw
- # Step 2: Export internal MJCF representation
- tmp_mjcf = tempfile.mktemp(suffix=".xml")
+ # Step 2: Export internal MJCF representation (save next to original
+ # model so relative mesh/asset paths resolve correctly on reload)
+ tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
+
+ # Step 3: Inject actuators into the MJCF XML
+ # Use torque actuator. Speed is limited by joint damping:
+ # at steady state, vel_max = torque / damping.
+ root = ET.fromstring(mjcf_str)
+ act_elem = ET.SubElement(root, "actuator")
+ for act in actuators:
+ ET.SubElement(act_elem, "motor", attrib={
+ "name": f"{act.joint}_motor",
+ "joint": act.joint,
+ "gear": str(act.gear),
+ "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
+ })
+
+ # Add damping to actuated joints to limit max speed and
+ # mimic real motor friction / back-EMF.
+ # vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
+ actuated_joints = {a.joint for a in actuators}
+ for body in root.iter("body"):
+ for jnt in body.findall("joint"):
+ if jnt.get("name") in actuated_joints:
+ jnt.set("damping", "0.05")
+
+ # Disable self-collision on all geoms.
+ # URDF mesh convex hulls often overlap at joints (especially
+ # grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
+ # causing phantom contact forces.
+ for geom in root.iter("geom"):
+ geom.set("contype", "0")
+ geom.set("conaffinity", "0")
+
+ # Step 4: Write modified MJCF and reload from file path
+ # (from_xml_path resolves mesh paths relative to the file location)
+ modified_xml = ET.tostring(root, encoding="unicode")
+ with open(tmp_mjcf, "w") as f:
+ f.write(modified_xml)
+ return mujoco.MjModel.from_xml_path(tmp_mjcf)
finally:
- import os
- os.unlink(tmp_mjcf)
-
- # Step 3: Inject actuators into the MJCF XML
- root = ET.fromstring(mjcf_str)
- act_elem = ET.SubElement(root, "actuator")
- for act in actuators:
- ET.SubElement(act_elem, "motor", attrib={
- "name": f"{act.joint}_motor",
- "joint": act.joint,
- "gear": str(act.gear),
- "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
- })
-
- # Step 4: Reload from modified MJCF
- modified_xml = ET.tostring(root, encoding="unicode")
- return mujoco.MjModel.from_xml_string(modified_xml)
+ if os.path.exists(tmp_mjcf):
+ os.unlink(tmp_mjcf)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._nq = self._model.nq
self._nv = self._model.nv
+ # Per-env smoothed ctrl state for EMA action filtering.
+ # Models real motor inertia: ctrl can't reverse instantly.
+ nu = self._model.nu
+ self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
+
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
+ alpha = self.config.action_ema_alpha
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data):
- data.ctrl[:] = actions_np[i]
+ # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
+ self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
+ data.ctrl[:] = self._smooth_ctrl[i]
for _ in range(self.config.substeps):
mujoco.mj_step(self._model, data)
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
+ default_qpos = self.env.get_default_qpos(self._nq)
+
for i, env_id in enumerate(ids):
data = self._data[env_id]
mujoco.mj_resetData(self._model, data)
- # Add small random perturbation so the pole doesn't start perfectly upright
+ # Set initial pose (env-specific, e.g. pendulum hanging down)
+ if default_qpos is not None:
+ data.qpos[:] = default_qpos
+
+ # Add small random perturbation for exploration
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
+ # Reset smoothed ctrl so motor starts from rest
+ self._smooth_ctrl[env_id][:] = 0.0
+
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
)
def _sim_close(self) -> None:
- if hasattr(self, "_viewer") and self._viewer is not None:
- self._viewer.close()
- self._viewer = None
-
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
self._offscreen_renderer = None
-
self._data.clear()
- def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
- if mode == "human":
- if not hasattr(self, "_viewer") or self._viewer is None:
- self._viewer = mujoco.viewer.launch_passive(
- self._model, self._data[env_idx]
- )
- # Update visual geometry from current physics state
- mujoco.mj_forward(self._model, self._data[env_idx])
- self._viewer.sync()
- return None
- elif mode == "rgb_array":
- # Cache the offscreen renderer to avoid create/destroy overhead
- if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
- self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
- self._offscreen_renderer.update_scene(self._data[env_idx])
- pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
- return torch.from_numpy(pixels)
\ No newline at end of file
+ def render(self, env_idx: int = 0) -> np.ndarray | None:
+ """Offscreen render → RGB numpy array (H, W, 3)."""
+ if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
+ self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
+ self._offscreen_renderer.update_scene(self._data[env_idx])
+ return self._offscreen_renderer.render().copy()
\ No newline at end of file
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 5d5a18e..3315156 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -4,19 +4,22 @@ import tempfile
from pathlib import Path
import numpy as np
+import torch
import tqdm
+from clearml import Logger
+from gymnasium import spaces
+from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
+from skrl.memories.torch import RandomMemory
+from skrl.resources.preprocessors.torch import RunningStandardScaler
+from skrl.trainers.torch import SequentialTrainer
from src.core.runner import BaseRunner
-from clearml import Task, Logger
-import torch
-from gymnasium import spaces
-from skrl.memories.torch import RandomMemory
from src.models.mlp import SharedMLP
-from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
-from skrl.trainers.torch import SequentialTrainer
+
@dataclasses.dataclass
class TrainerConfig:
+ # PPO
rollout_steps: int = 2048
learning_epochs: int = 8
mini_batches: int = 4
@@ -29,30 +32,27 @@ class TrainerConfig:
hidden_sizes: tuple[int, ...] = (64, 64)
+ # Training
total_timesteps: int = 1_000_000
log_interval: int = 10
- # Video recording
- record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
- record_video_min_seconds: float = 10.0 # minimum video duration in seconds
- record_video_fps: int = 0 # 0 = auto-derive from simulation rate
+ # Video recording (uploaded to ClearML)
+ record_video_every: int = 10_000 # 0 = disabled
+ record_video_fps: int = 0 # 0 = derive from sim dt×substeps
- clearml_project: str | None = None
- clearml_task: str | None = None
+# ── Video-recording trainer ──────────────────────────────────────────
class VideoRecordingTrainer(SequentialTrainer):
- """Subclass of skrl's SequentialTrainer that records videos periodically."""
+ """SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
super().__init__(env=env, agents=agents, cfg=cfg)
- self._trainer_config = trainer_config
+ self._tcfg = trainer_config
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
def single_agent_train(self) -> None:
- """Override to add periodic video recording."""
- assert self.num_simultaneous_agents == 1
- assert self.env.num_agents == 1
+ assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
states, infos = self.env.reset()
@@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer):
disable=self.disable_progressbar,
file=sys.stdout,
):
- # Pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
with torch.no_grad():
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
- if not self.headless:
- self.env.render()
-
self.agents.record_transition(
- states=states,
- actions=actions,
- rewards=rewards,
- next_states=next_states,
- terminated=terminated,
- truncated=truncated,
- infos=infos,
- timestep=timestep,
- timesteps=self.timesteps,
+ states=states, actions=actions, rewards=rewards,
+ next_states=next_states, terminated=terminated,
+ truncated=truncated, infos=infos,
+ timestep=timestep, timesteps=self.timesteps,
)
if self.environment_info in infos:
@@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer):
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
- # Reset environments
+ # Auto-reset for multi-env; single-env resets manually
if self.env.num_envs > 1:
states = next_states
else:
@@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer):
else:
states = next_states
- # Record video at intervals
- cfg = self._trainer_config
+ # Periodic video recording
if (
- cfg
- and cfg.record_video_every > 0
- and (timestep + 1) % cfg.record_video_every == 0
+ self._tcfg
+ and self._tcfg.record_video_every > 0
+ and (timestep + 1) % self._tcfg.record_video_every == 0
):
self._record_video(timestep + 1)
- def _get_video_fps(self) -> int:
- """Derive video fps from the simulation rate, or use configured value."""
- cfg = self._trainer_config
- if cfg.record_video_fps > 0:
- return cfg.record_video_fps
- # Auto-derive from runner's simulation parameters
- runner = self.env
- dt = getattr(runner.config, "dt", 0.02)
- substeps = getattr(runner.config, "substeps", 1)
+ # ── helpers ───────────────────────────────────────────────────────
+
+ def _get_fps(self) -> int:
+ if self._tcfg and self._tcfg.record_video_fps > 0:
+ return self._tcfg.record_video_fps
+ dt = getattr(self.env.config, "dt", 0.02)
+ substeps = getattr(self.env.config, "substeps", 1)
return max(1, int(round(1.0 / (dt * substeps))))
def _record_video(self, timestep: int) -> None:
- """Record evaluation episodes and upload to ClearML."""
try:
import imageio.v3 as iio
except ImportError:
- try:
- import imageio as iio
- except ImportError:
- return
+ return
- cfg = self._trainer_config
- fps = self._get_video_fps()
- min_frames = int(cfg.record_video_min_seconds * fps)
- max_frames = min_frames * 3 # hard cap to prevent runaway recording
+ fps = self._get_fps()
+ max_steps = getattr(self.env.env.config, "max_steps", 500)
frames: list[np.ndarray] = []
- while len(frames) < min_frames and len(frames) < max_frames:
- obs, _ = self.env.reset()
- done = False
- steps = 0
- max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
- while not done and steps < max_episode_steps:
- with torch.no_grad():
- action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
- obs, _, terminated, truncated, _ = self.env.step(action)
- frame = self.env.render(mode="rgb_array")
- if frame is not None:
- frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
- done = (terminated | truncated).any().item()
- steps += 1
- if len(frames) >= max_frames:
- break
+ obs, _ = self.env.reset()
+ for _ in range(max_steps):
+ with torch.no_grad():
+ action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
+ obs, _, terminated, truncated, _ = self.env.step(action)
+
+ frame = self.env.render()
+ if frame is not None:
+ frames.append(frame)
+
+ if (terminated | truncated).any().item():
+ break
if frames:
- video_path = str(self._video_dir / f"step_{timestep}.mp4")
- iio.imwrite(video_path, frames, fps=fps)
+ path = str(self._video_dir / f"step_{timestep}.mp4")
+ iio.imwrite(path, frames, fps=fps)
logger = Logger.current_logger()
if logger:
logger.report_media(
- title="Training Video",
- series=f"step_{timestep}",
- local_path=video_path,
- iteration=timestep,
+ "Training Video", f"step_{timestep}",
+ local_path=path, iteration=timestep,
)
- # Reset back to training state after recording
+ # Restore training state
self.env.reset()
+
+# ── Main trainer ─────────────────────────────────────────────────────
+
class Trainer:
def __init__(self, runner: BaseRunner, config: TrainerConfig):
self.runner = runner
self.config = config
-
- self._init_clearml()
self._init_agent()
- def _init_clearml(self) -> None:
- if self.config.clearml_project and self.config.clearml_task:
- self.clearml_task = Task.init(
- project_name=self.config.clearml_project,
- task_name=self.config.clearml_task,
- )
- else:
- self.clearml_task = None
-
def _init_agent(self) -> None:
- device: torch.device = self.runner.device
- obs_space: spaces.Space = self.runner.observation_space
- act_space: spaces.Space = self.runner.action_space
- num_envs: int = self.runner.num_envs
+ device = self.runner.device
+ obs_space = self.runner.observation_space
+ act_space = self.runner.action_space
- self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
+ self.memory = RandomMemory(
+ memory_size=self.config.rollout_steps,
+ num_envs=self.runner.num_envs,
+ device=device,
+ )
- self.model: SharedMLP = SharedMLP(
+ self.model = SharedMLP(
observation_space=obs_space,
action_space=act_space,
device=device,
hidden_sizes=self.config.hidden_sizes,
+ initial_log_std=0.5,
+ min_log_std=-2.0,
)
- models = {
- "policy": self.model,
- "value": self.model,
- }
+ models = {"policy": self.model, "value": self.model}
agent_cfg = PPO_DEFAULT_CONFIG.copy()
agent_cfg.update({
@@ -217,9 +187,19 @@ class Trainer:
"ratio_clip": self.config.clip_ratio,
"value_loss_scale": self.config.value_loss_scale,
"entropy_loss_scale": self.config.entropy_loss_scale,
+ "state_preprocessor": RunningStandardScaler,
+ "state_preprocessor_kwargs": {"size": obs_space, "device": device},
+ "value_preprocessor": RunningStandardScaler,
+ "value_preprocessor_kwargs": {"size": 1, "device": device},
})
+    # Wire up logging frequency: write_interval is measured in timesteps,
+    # so log_interval=10 → scalars are written every 10 env timesteps.
+ agent_cfg["experiment"]["write_interval"] = self.config.log_interval
+ agent_cfg["experiment"]["checkpoint_interval"] = max(
+ self.config.total_timesteps // 10, self.config.rollout_steps
+ )
- self.agent: PPO = PPO(
+ self.agent = PPO(
models=models,
memory=self.memory,
observation_space=obs_space,
@@ -238,6 +218,4 @@ class Trainer:
trainer.train()
def close(self) -> None:
- self.runner.close()
- if self.clearml_task:
- self.clearml_task.close()
\ No newline at end of file
+ self.runner.close()
\ No newline at end of file
diff --git a/train.py b/train.py
index aa40206..b133bbc 100644
--- a/train.py
+++ b/train.py
@@ -1,39 +1,80 @@
import hydra
+from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
+from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
+from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
-from src.core.env import ActuatorConfig
+
+# ── env registry ──────────────────────────────────────────────────────
+# Maps Hydra config-group name → (EnvClass, ConfigClass)
+ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
+ "cartpole": (CartPoleEnv, CartPoleConfig),
+ "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
+}
-def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
+def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
+ """Instantiate the right env + config from the Hydra config-group name."""
+ if env_name not in ENV_REGISTRY:
+ raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
+
+ env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
+
+ # Convert actuator dicts → ActuatorConfig objects
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
- return CartPoleConfig(**env_dict)
+
+ return env_cls(config_cls(**env_dict))
+
+
+def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
+ """Initialize ClearML task with project structure and tags.
+
+ Project: RL-Trainings/ (e.g. RL-Trainings/Rotary Cartpole)
+ Tags: env, runner, training algo choices from Hydra.
+ """
+ Task.ignore_requirements("torch")
+
+ env_name = choices.get("env", "cartpole")
+ runner_name = choices.get("runner", "mujoco")
+ training_name = choices.get("training", "ppo")
+
+ project = "RL-Framework"
+ task_name = f"{env_name}-{runner_name}-{training_name}"
+ tags = [env_name, runner_name, training_name]
+
+ task = Task.init(project_name=project, task_name=task_name, tags=tags)
+
+ if remote:
+ task.execute_remotely(queue_name="default")
+
+ return task
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
- env_config = _build_env_config(cfg)
- runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
+ choices = HydraConfig.get().runtime.choices
+ # ClearML init — must happen before heavy work so remote execution
+ # can take over early. The remote worker re-runs the full script;
+ # execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
- # Build ClearML task name dynamically from Hydra config group choices
- if not training_dict.get("clearml_task"):
- choices = HydraConfig.get().runtime.choices
- env_name = choices.get("env", "env")
- runner_name = choices.get("runner", "runner")
- training_name = choices.get("training", "algo")
- training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
+ remote = training_dict.pop("remote", False)
+ task = _init_clearml(choices, remote=remote)
+
+ env_name = choices.get("env", "cartpole")
+ env = _build_env(env_name, cfg)
+ runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
trainer_config = TrainerConfig(**training_dict)
- env = CartPoleEnv(env_config)
runner = MuJoCoRunner(env=env, config=runner_config)
trainer = Trainer(runner=runner, config=trainer_config)
@@ -41,6 +82,7 @@ def main(cfg: DictConfig) -> None:
trainer.train()
finally:
trainer.close()
+ task.close()
if __name__ == "__main__":