diff --git a/assets/rotary_cartpole/meshes/arm_1.stl b/assets/rotary_cartpole/meshes/arm_1.stl new file mode 100644 index 0000000..2cfeba8 Binary files /dev/null and b/assets/rotary_cartpole/meshes/arm_1.stl differ diff --git a/assets/rotary_cartpole/meshes/base_link.stl b/assets/rotary_cartpole/meshes/base_link.stl new file mode 100644 index 0000000..d56b548 Binary files /dev/null and b/assets/rotary_cartpole/meshes/base_link.stl differ diff --git a/assets/rotary_cartpole/meshes/pendulum_1.stl b/assets/rotary_cartpole/meshes/pendulum_1.stl new file mode 100644 index 0000000..73d39ac Binary files /dev/null and b/assets/rotary_cartpole/meshes/pendulum_1.stl differ diff --git a/assets/rotary_cartpole/rotary_cartpole.urdf b/assets/rotary_cartpole/rotary_cartpole.urdf new file mode 100644 index 0000000..b68ed3b --- /dev/null +++ b/assets/rotary_cartpole/rotary_cartpole.urdf @@ -0,0 +1,105 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/configs/env/rotary_cartpole.yaml b/configs/env/rotary_cartpole.yaml new file mode 100644 index 0000000..c4fb61c --- /dev/null +++ b/configs/env/rotary_cartpole.yaml @@ -0,0 +1,7 @@ +max_steps: 1000 +model_path: assets/rotary_cartpole/rotary_cartpole.urdf +reward_upright_scale: 1.0 +actuators: + - joint: motor_joint + gear: 15.0 + ctrl_range: [-1.0, 1.0] diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml index 42a54b5..861b8ba 100644 --- a/configs/runner/mujoco.yaml +++ b/configs/runner/mujoco.yaml @@ -1,4 +1,5 @@ -num_envs: 16 +num_envs: 64 device: cpu -dt: 0.02 -substeps: 2 +dt: 0.002 +substeps: 20 +action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse) diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml index 8025216..c078de6 100644 --- a/configs/training/ppo.yaml +++ b/configs/training/ppo.yaml @@ -1,5 +1,5 @@ hidden_sizes: [128, 128] -total_timesteps: 1000000 +total_timesteps: 5000000 rollout_steps: 1024 learning_epochs: 4 mini_batches: 4 @@ -8,6 +8,8 @@ gae_lambda: 0.95 learning_rate: 0.0003 clip_ratio: 0.2 value_loss_scale: 0.5 -entropy_loss_scale: 0.01 +entropy_loss_scale: 0.05 log_interval: 10 -clearml_project: RL-Framework + +# ClearML remote execution (GPU worker) +remote: false diff --git a/requirements.txt b/requirements.txt index f40d2e8..f3ed7df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ omegaconf mujoco skrl[torch] clearml +imageio +imageio-ffmpeg pytest \ No newline at end of file diff --git a/src/core/env.py b/src/core/env.py index 2654898..80c5d96 100644 --- a/src/core/env.py +++ b/src/core/env.py @@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]): def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor: return step_counts >= self.config.max_steps + + def get_default_qpos(self, nq: int) -> list[float] | None: + """Return the default joint positions for reset. + Override in subclass if the URDF zero pose doesn't match + the desired initial state (e.g. pendulum hanging down). + Returns None to use the URDF default (all zeros).""" + return None diff --git a/src/core/runner.py b/src/core/runner.py index f3e17fa..ff00890 100644 --- a/src/core/runner.py +++ b/src/core/runner.py @@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]): # skrl expects (num_envs, 1) for rewards/terminated/truncated return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info - def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None: - raise NotImplementedError("Render method not implemented for this runner.") + def render(self, env_idx: int = 0): + """Offscreen render → RGB numpy array. Override in subclass.""" + raise NotImplementedError("Render not implemented for this runner.") def close(self) -> None: self._sim_close() \ No newline at end of file diff --git a/src/envs/rotary_cartpole.py b/src/envs/rotary_cartpole.py new file mode 100644 index 0000000..d70dcf2 --- /dev/null +++ b/src/envs/rotary_cartpole.py @@ -0,0 +1,94 @@ +import dataclasses +import torch +from src.core.env import BaseEnv, BaseEnvConfig +from gymnasium import spaces + + +@dataclasses.dataclass +class RotaryCartPoleState: + motor_angle: torch.Tensor # (num_envs,) + motor_vel: torch.Tensor # (num_envs,) + pendulum_angle: torch.Tensor # (num_envs,) + pendulum_vel: torch.Tensor # (num_envs,) + + +@dataclasses.dataclass +class RotaryCartPoleConfig(BaseEnvConfig): + """Rotary inverted pendulum (Furuta pendulum) task config. + + The motor rotates the arm horizontally; the pendulum swings freely + at the arm tip. Goal: swing the pendulum up and balance it upright. + """ + # Reward shaping + reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1 + + +class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]): + """Furuta pendulum / rotary inverted pendulum environment. + + Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum + + Observations (6): + [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel] + Using sin/cos avoids discontinuities at ±π for continuous joints. + + Actions (1): + Torque applied to the motor_joint. + """ + + def __init__(self, config: RotaryCartPoleConfig): + super().__init__(config) + + @property + def observation_space(self) -> spaces.Space: + return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,)) + + @property + def action_space(self) -> spaces.Space: + return spaces.Box(low=-1.0, high=1.0, shape=(1,)) + + def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState: + return RotaryCartPoleState( + motor_angle=qpos[:, 0], + motor_vel=qvel[:, 0], + pendulum_angle=qpos[:, 1], + pendulum_vel=qvel[:, 1], + ) + + def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor: + return torch.stack([ + torch.sin(state.motor_angle), + torch.cos(state.motor_angle), + torch.sin(state.pendulum_angle), + torch.cos(state.pendulum_angle), + state.motor_vel, + state.pendulum_vel, + ], dim=-1) + + def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor: + # height: sin(θ) → -1 (down) to +1 (up) + height = torch.sin(state.pendulum_angle) + + # Upright reward: strongly rewards being near vertical. + # Uses cos(θ - π/2) = sin(θ), squared and scaled so: + # down (h=-1): 0.0 + # horiz (h= 0): 0.0 + # up (h=+1): 1.0 + # Only kicks in above horizontal, so swing-up isn't penalised. + upright_reward = torch.clamp(height, 0.0, 1.0) ** 2 + + # Motor effort penalty: small cost to avoid bang-bang control. + effort_penalty = 0.001 * actions.squeeze(-1) ** 2 + + return 5.0 * upright_reward - effort_penalty + + def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor: + # No early termination — episode runs for max_steps (truncation only). + # The agent must learn to swing up AND balance continuously. + return torch.zeros_like(state.motor_angle, dtype=torch.bool) + + def get_default_qpos(self, nq: int) -> list[float] | None: + # The STL mesh is horizontal at qpos=0. + # Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1). + import math + return [0.0, -math.pi / 2] diff --git a/src/models/mlp.py b/src/models/mlp.py index ebc4537..398c6a5 100644 --- a/src/models/mlp.py +++ b/src/models/mlp.py @@ -4,7 +4,7 @@ from gymnasium import spaces from skrl.models.torch import Model, GaussianMixin, DeterministicMixin class SharedMLP(GaussianMixin, DeterministicMixin, Model): - def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0): + def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0): Model.__init__(self, observation_space, action_space, device) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) DeterministicMixin.__init__(self, clip_actions) diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index d91e54e..7cf7755 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -1,12 +1,11 @@ import dataclasses -import tempfile +import os import xml.etree.ElementTree as ET from src.core.env import BaseEnv, ActuatorConfig from src.core.runner import BaseRunner, BaseRunnerConfig import torch import numpy as np import mujoco -import mujoco.viewer @dataclasses.dataclass class MuJoCoRunnerConfig(BaseRunnerConfig): @@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig): device: str = "cpu" dt: float = 0.02 substeps: int = 2 + action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant) class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig): @@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): This keeps the URDF clean and standard — actuator config lives in the env config (Isaac Lab pattern), not in the robot file. """ - # Step 1: Load URDF/MJCF as-is (no actuators) - model_raw = mujoco.MjModel.from_xml_path(model_path) + abs_path = os.path.abspath(model_path) + model_dir = os.path.dirname(abs_path) + is_urdf = abs_path.lower().endswith(".urdf") + + # MuJoCo's URDF parser strips directory prefixes from mesh filenames, + # so we inject a block into a + # temporary copy. The original URDF stays clean and simulator-agnostic. + if is_urdf: + tree = ET.parse(abs_path) + root = tree.getroot() + # Detect the mesh subdirectory from the first mesh filename + meshdir = None + for mesh_el in root.iter("mesh"): + fn = mesh_el.get("filename", "") + dirname = os.path.dirname(fn) + if dirname: + meshdir = dirname + break + if meshdir: + mj_ext = ET.SubElement(root, "mujoco") + ET.SubElement(mj_ext, "compiler", attrib={ + "meshdir": meshdir, + "balanceinertia": "true", + }) + tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf") + tree.write(tmp_urdf, xml_declaration=True, encoding="unicode") + try: + model_raw = mujoco.MjModel.from_xml_path(tmp_urdf) + finally: + os.unlink(tmp_urdf) + else: + model_raw = mujoco.MjModel.from_xml_path(abs_path) if not actuators: return model_raw - # Step 2: Export internal MJCF representation - tmp_mjcf = tempfile.mktemp(suffix=".xml") + # Step 2: Export internal MJCF representation (save next to original + # model so relative mesh/asset paths resolve correctly on reload) + tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml") try: mujoco.mj_saveLastXML(tmp_mjcf, model_raw) with open(tmp_mjcf) as f: mjcf_str = f.read() + + # Step 3: Inject actuators into the MJCF XML + # Use torque actuator. Speed is limited by joint damping: + # at steady state, vel_max = torque / damping. + root = ET.fromstring(mjcf_str) + act_elem = ET.SubElement(root, "actuator") + for act in actuators: + ET.SubElement(act_elem, "motor", attrib={ + "name": f"{act.joint}_motor", + "joint": act.joint, + "gear": str(act.gear), + "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", + }) + + # Add damping to actuated joints to limit max speed and + # mimic real motor friction / back-EMF. + # vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s) + actuated_joints = {a.joint for a in actuators} + for body in root.iter("body"): + for jnt in body.findall("joint"): + if jnt.get("name") in actuated_joints: + jnt.set("damping", "0.05") + + # Disable self-collision on all geoms. + # URDF mesh convex hulls often overlap at joints (especially + # grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude), + # causing phantom contact forces. + for geom in root.iter("geom"): + geom.set("contype", "0") + geom.set("conaffinity", "0") + + # Step 4: Write modified MJCF and reload from file path + # (from_xml_path resolves mesh paths relative to the file location) + modified_xml = ET.tostring(root, encoding="unicode") + with open(tmp_mjcf, "w") as f: + f.write(modified_xml) + return mujoco.MjModel.from_xml_path(tmp_mjcf) finally: - import os - os.unlink(tmp_mjcf) - - # Step 3: Inject actuators into the MJCF XML - root = ET.fromstring(mjcf_str) - act_elem = ET.SubElement(root, "actuator") - for act in actuators: - ET.SubElement(act_elem, "motor", attrib={ - "name": f"{act.joint}_motor", - "joint": act.joint, - "gear": str(act.gear), - "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", - }) - - # Step 4: Reload from modified MJCF - modified_xml = ET.tostring(root, encoding="unicode") - return mujoco.MjModel.from_xml_string(modified_xml) + if os.path.exists(tmp_mjcf): + os.unlink(tmp_mjcf) def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None: model_path = self.env.config.model_path @@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): self._nq = self._model.nq self._nv = self._model.nv + # Per-env smoothed ctrl state for EMA action filtering. + # Models real motor inertia: ctrl can't reverse instantly. + nu = self._model.nu + self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)] + def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: actions_np: np.ndarray = actions.cpu().numpy() + alpha = self.config.action_ema_alpha qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32) qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32) for i, data in enumerate(self._data): - data.ctrl[:] = actions_np[i] + # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl + self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i] + data.ctrl[:] = self._smooth_ctrl[i] for _ in range(self.config.substeps): mujoco.mj_step(self._model, data) @@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): qpos_batch = np.zeros((n, self._nq), dtype=np.float32) qvel_batch = np.zeros((n, self._nv), dtype=np.float32) + default_qpos = self.env.get_default_qpos(self._nq) + for i, env_id in enumerate(ids): data = self._data[env_id] mujoco.mj_resetData(self._model, data) - # Add small random perturbation so the pole doesn't start perfectly upright + # Set initial pose (env-specific, e.g. pendulum hanging down) + if default_qpos is not None: + data.qpos[:] = default_qpos + + # Add small random perturbation for exploration data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq) data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv) + # Reset smoothed ctrl so motor starts from rest + self._smooth_ctrl[env_id][:] = 0.0 + qpos_batch[i] = data.qpos qvel_batch[i] = data.qvel @@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): ) def _sim_close(self) -> None: - if hasattr(self, "_viewer") and self._viewer is not None: - self._viewer.close() - self._viewer = None - if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: self._offscreen_renderer.close() self._offscreen_renderer = None - self._data.clear() - def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None: - if mode == "human": - if not hasattr(self, "_viewer") or self._viewer is None: - self._viewer = mujoco.viewer.launch_passive( - self._model, self._data[env_idx] - ) - # Update visual geometry from current physics state - mujoco.mj_forward(self._model, self._data[env_idx]) - self._viewer.sync() - return None - elif mode == "rgb_array": - # Cache the offscreen renderer to avoid create/destroy overhead - if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None: - self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640) - self._offscreen_renderer.update_scene(self._data[env_idx]) - pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused - return torch.from_numpy(pixels) \ No newline at end of file + def render(self, env_idx: int = 0) -> np.ndarray | None: + """Offscreen render → RGB numpy array (H, W, 3).""" + if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None: + self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640) + self._offscreen_renderer.update_scene(self._data[env_idx]) + return self._offscreen_renderer.render().copy() \ No newline at end of file diff --git a/src/training/trainer.py b/src/training/trainer.py index 5d5a18e..3315156 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -4,19 +4,22 @@ import tempfile from pathlib import Path import numpy as np +import torch import tqdm +from clearml import Logger +from gymnasium import spaces +from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG +from skrl.memories.torch import RandomMemory +from skrl.resources.preprocessors.torch import RunningStandardScaler +from skrl.trainers.torch import SequentialTrainer from src.core.runner import BaseRunner -from clearml import Task, Logger -import torch -from gymnasium import spaces -from skrl.memories.torch import RandomMemory from src.models.mlp import SharedMLP -from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG -from skrl.trainers.torch import SequentialTrainer + @dataclasses.dataclass class TrainerConfig: + # PPO rollout_steps: int = 2048 learning_epochs: int = 8 mini_batches: int = 4 @@ -29,30 +32,27 @@ class TrainerConfig: hidden_sizes: tuple[int, ...] = (64, 64) + # Training total_timesteps: int = 1_000_000 log_interval: int = 10 - # Video recording - record_video_every: int = 10000 # record a video every N timesteps (0 = disabled) - record_video_min_seconds: float = 10.0 # minimum video duration in seconds - record_video_fps: int = 0 # 0 = auto-derive from simulation rate + # Video recording (uploaded to ClearML) + record_video_every: int = 10_000 # 0 = disabled + record_video_fps: int = 0 # 0 = derive from sim dt×substeps - clearml_project: str | None = None - clearml_task: str | None = None +# ── Video-recording trainer ────────────────────────────────────────── class VideoRecordingTrainer(SequentialTrainer): - """Subclass of skrl's SequentialTrainer that records videos periodically.""" + """SequentialTrainer with periodic evaluation videos uploaded to ClearML.""" def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None): super().__init__(env=env, agents=agents, cfg=cfg) - self._trainer_config = trainer_config + self._tcfg = trainer_config self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_")) def single_agent_train(self) -> None: - """Override to add periodic video recording.""" - assert self.num_simultaneous_agents == 1 - assert self.env.num_agents == 1 + assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1 states, infos = self.env.reset() @@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer): disable=self.disable_progressbar, file=sys.stdout, ): - # Pre-interaction self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps) with torch.no_grad(): actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0] next_states, rewards, terminated, truncated, infos = self.env.step(actions) - if not self.headless: - self.env.render() - self.agents.record_transition( - states=states, - actions=actions, - rewards=rewards, - next_states=next_states, - terminated=terminated, - truncated=truncated, - infos=infos, - timestep=timestep, - timesteps=self.timesteps, + states=states, actions=actions, rewards=rewards, + next_states=next_states, terminated=terminated, + truncated=truncated, infos=infos, + timestep=timestep, timesteps=self.timesteps, ) if self.environment_info in infos: @@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer): self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps) - # Reset environments + # Auto-reset for multi-env; single-env resets manually if self.env.num_envs > 1: states = next_states else: @@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer): else: states = next_states - # Record video at intervals - cfg = self._trainer_config + # Periodic video recording if ( - cfg - and cfg.record_video_every > 0 - and (timestep + 1) % cfg.record_video_every == 0 + self._tcfg + and self._tcfg.record_video_every > 0 + and (timestep + 1) % self._tcfg.record_video_every == 0 ): self._record_video(timestep + 1) - def _get_video_fps(self) -> int: - """Derive video fps from the simulation rate, or use configured value.""" - cfg = self._trainer_config - if cfg.record_video_fps > 0: - return cfg.record_video_fps - # Auto-derive from runner's simulation parameters - runner = self.env - dt = getattr(runner.config, "dt", 0.02) - substeps = getattr(runner.config, "substeps", 1) + # ── helpers ─────────────────────────────────────────────────────── + + def _get_fps(self) -> int: + if self._tcfg and self._tcfg.record_video_fps > 0: + return self._tcfg.record_video_fps + dt = getattr(self.env.config, "dt", 0.02) + substeps = getattr(self.env.config, "substeps", 1) return max(1, int(round(1.0 / (dt * substeps)))) def _record_video(self, timestep: int) -> None: - """Record evaluation episodes and upload to ClearML.""" try: import imageio.v3 as iio except ImportError: - try: - import imageio as iio - except ImportError: - return + return - cfg = self._trainer_config - fps = self._get_video_fps() - min_frames = int(cfg.record_video_min_seconds * fps) - max_frames = min_frames * 3 # hard cap to prevent runaway recording + fps = self._get_fps() + max_steps = getattr(self.env.env.config, "max_steps", 500) frames: list[np.ndarray] = [] - while len(frames) < min_frames and len(frames) < max_frames: - obs, _ = self.env.reset() - done = False - steps = 0 - max_episode_steps = getattr(self.env.env.config, "max_steps", 500) - while not done and steps < max_episode_steps: - with torch.no_grad(): - action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0] - obs, _, terminated, truncated, _ = self.env.step(action) - frame = self.env.render(mode="rgb_array") - if frame is not None: - frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame) - done = (terminated | truncated).any().item() - steps += 1 - if len(frames) >= max_frames: - break + obs, _ = self.env.reset() + for _ in range(max_steps): + with torch.no_grad(): + action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0] + obs, _, terminated, truncated, _ = self.env.step(action) + + frame = self.env.render() + if frame is not None: + frames.append(frame) + + if (terminated | truncated).any().item(): + break if frames: - video_path = str(self._video_dir / f"step_{timestep}.mp4") - iio.imwrite(video_path, frames, fps=fps) + path = str(self._video_dir / f"step_{timestep}.mp4") + iio.imwrite(path, frames, fps=fps) logger = Logger.current_logger() if logger: logger.report_media( - title="Training Video", - series=f"step_{timestep}", - local_path=video_path, - iteration=timestep, + "Training Video", f"step_{timestep}", + local_path=path, iteration=timestep, ) - # Reset back to training state after recording + # Restore training state self.env.reset() + +# ── Main trainer ───────────────────────────────────────────────────── + class Trainer: def __init__(self, runner: BaseRunner, config: TrainerConfig): self.runner = runner self.config = config - - self._init_clearml() self._init_agent() - def _init_clearml(self) -> None: - if self.config.clearml_project and self.config.clearml_task: - self.clearml_task = Task.init( - project_name=self.config.clearml_project, - task_name=self.config.clearml_task, - ) - else: - self.clearml_task = None - def _init_agent(self) -> None: - device: torch.device = self.runner.device - obs_space: spaces.Space = self.runner.observation_space - act_space: spaces.Space = self.runner.action_space - num_envs: int = self.runner.num_envs + device = self.runner.device + obs_space = self.runner.observation_space + act_space = self.runner.action_space - self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device) + self.memory = RandomMemory( + memory_size=self.config.rollout_steps, + num_envs=self.runner.num_envs, + device=device, + ) - self.model: SharedMLP = SharedMLP( + self.model = SharedMLP( observation_space=obs_space, action_space=act_space, device=device, hidden_sizes=self.config.hidden_sizes, + initial_log_std=0.5, + min_log_std=-2.0, ) - models = { - "policy": self.model, - "value": self.model, - } + models = {"policy": self.model, "value": self.model} agent_cfg = PPO_DEFAULT_CONFIG.copy() agent_cfg.update({ @@ -217,9 +187,19 @@ class Trainer: "ratio_clip": self.config.clip_ratio, "value_loss_scale": self.config.value_loss_scale, "entropy_loss_scale": self.config.entropy_loss_scale, + "state_preprocessor": RunningStandardScaler, + "state_preprocessor_kwargs": {"size": obs_space, "device": device}, + "value_preprocessor": RunningStandardScaler, + "value_preprocessor_kwargs": {"size": 1, "device": device}, }) + # Wire up logging frequency: write_interval is in timesteps. + # log_interval=1 → log every PPO update (= every rollout_steps timesteps). + agent_cfg["experiment"]["write_interval"] = self.config.log_interval + agent_cfg["experiment"]["checkpoint_interval"] = max( + self.config.total_timesteps // 10, self.config.rollout_steps + ) - self.agent: PPO = PPO( + self.agent = PPO( models=models, memory=self.memory, observation_space=obs_space, @@ -238,6 +218,4 @@ class Trainer: trainer.train() def close(self) -> None: - self.runner.close() - if self.clearml_task: - self.clearml_task.close() \ No newline at end of file + self.runner.close() \ No newline at end of file diff --git a/train.py b/train.py index aa40206..b133bbc 100644 --- a/train.py +++ b/train.py @@ -1,39 +1,80 @@ import hydra +from clearml import Task from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf +from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig from src.envs.cartpole import CartPoleEnv, CartPoleConfig +from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig from src.training.trainer import Trainer, TrainerConfig -from src.core.env import ActuatorConfig + +# ── env registry ────────────────────────────────────────────────────── +# Maps Hydra config-group name → (EnvClass, ConfigClass) +ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = { + "cartpole": (CartPoleEnv, CartPoleConfig), + "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig), +} -def _build_env_config(cfg: DictConfig) -> CartPoleConfig: +def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv: + """Instantiate the right env + config from the Hydra config-group name.""" + if env_name not in ENV_REGISTRY: + raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}") + + env_cls, config_cls = ENV_REGISTRY[env_name] env_dict = OmegaConf.to_container(cfg.env, resolve=True) + + # Convert actuator dicts → ActuatorConfig objects if "actuators" in env_dict: for a in env_dict["actuators"]: if "ctrl_range" in a: a["ctrl_range"] = tuple(a["ctrl_range"]) env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]] - return CartPoleConfig(**env_dict) + + return env_cls(config_cls(**env_dict)) + + +def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task: + """Initialize ClearML task with project structure and tags. + + Project: RL-Trainings/ (e.g. RL-Trainings/Rotary Cartpole) + Tags: env, runner, training algo choices from Hydra. + """ + Task.ignore_requirements("torch") + + env_name = choices.get("env", "cartpole") + runner_name = choices.get("runner", "mujoco") + training_name = choices.get("training", "ppo") + + project = "RL-Framework" + task_name = f"{env_name}-{runner_name}-{training_name}" + tags = [env_name, runner_name, training_name] + + task = Task.init(project_name=project, task_name=task_name, tags=tags) + + if remote: + task.execute_remotely(queue_name="default") + + return task @hydra.main(version_base=None, config_path="configs", config_name="config") def main(cfg: DictConfig) -> None: - env_config = _build_env_config(cfg) - runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True)) + choices = HydraConfig.get().runtime.choices + # ClearML init — must happen before heavy work so remote execution + # can take over early. The remote worker re-runs the full script; + # execute_remotely() is a no-op on the worker side. training_dict = OmegaConf.to_container(cfg.training, resolve=True) - # Build ClearML task name dynamically from Hydra config group choices - if not training_dict.get("clearml_task"): - choices = HydraConfig.get().runtime.choices - env_name = choices.get("env", "env") - runner_name = choices.get("runner", "runner") - training_name = choices.get("training", "algo") - training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}" + remote = training_dict.pop("remote", False) + task = _init_clearml(choices, remote=remote) + + env_name = choices.get("env", "cartpole") + env = _build_env(env_name, cfg) + runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True)) trainer_config = TrainerConfig(**training_dict) - env = CartPoleEnv(env_config) runner = MuJoCoRunner(env=env, config=runner_config) trainer = Trainer(runner=runner, config=trainer_config) @@ -41,6 +82,7 @@ def main(cfg: DictConfig) -> None: trainer.train() finally: trainer.close() + task.close() if __name__ == "__main__":