✨ initial commit
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
outputs/
|
||||
.vscode/
|
||||
runs/
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
RL-Framework-7914bb
|
||||
64
assets/cartpole/cartpole.urdf
Normal file
64
assets/cartpole/cartpole.urdf
Normal file
@@ -0,0 +1,64 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<robot name="cartpole">
|
||||
|
||||
<!-- World link (fixed base) -->
|
||||
<link name="world"/>
|
||||
|
||||
<!-- Cart (slides along x-axis) -->
|
||||
<link name="cart">
|
||||
<inertial>
|
||||
<mass value="1.0"/>
|
||||
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<geometry>
|
||||
<box size="0.3 0.2 0.1"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<geometry>
|
||||
<box size="0.3 0.2 0.1"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Cart slides along x-axis -->
|
||||
<joint name="cart_joint" type="prismatic">
|
||||
<parent link="world"/>
|
||||
<child link="cart"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
|
||||
</joint>
|
||||
|
||||
<!-- Pole (rotates around y-axis, attached on top of cart) -->
|
||||
<link name="pole">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<mass value="0.1"/>
|
||||
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.02" length="0.6"/>
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0.3"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.02" length="0.6"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Pole rotates freely (no motor) -->
|
||||
<joint name="pole_joint" type="revolute">
|
||||
<parent link="cart"/>
|
||||
<child link="pole"/>
|
||||
<origin xyz="0 0 0.05"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
|
||||
<dynamics damping="0.0" friction="0.0"/>
|
||||
</joint>
|
||||
|
||||
</robot>
|
||||
5
configs/config.yaml
Normal file
5
configs/config.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
defaults:
|
||||
- env: cartpole
|
||||
- runner: mujoco
|
||||
- training: ppo
|
||||
- _self_
|
||||
11
configs/env/cartpole.yaml
vendored
Normal file
11
configs/env/cartpole.yaml
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
max_steps: 500
|
||||
angle_threshold: 0.418
|
||||
cart_limit: 2.4
|
||||
reward_alive: 1.0
|
||||
reward_pole_upright_scale: 1.0
|
||||
reward_action_penalty_scale: 0.01
|
||||
model_path: assets/cartpole/cartpole.urdf
|
||||
actuators:
|
||||
- joint: cart_joint
|
||||
gear: 10.0
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
4
configs/runner/mujoco.yaml
Normal file
4
configs/runner/mujoco.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
num_envs: 16
|
||||
device: cpu
|
||||
dt: 0.02
|
||||
substeps: 2
|
||||
13
configs/training/ppo.yaml
Normal file
13
configs/training/ppo.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
hidden_sizes: [128, 128]
|
||||
total_timesteps: 1000000
|
||||
rollout_steps: 1024
|
||||
learning_epochs: 4
|
||||
mini_batches: 4
|
||||
discount_factor: 0.99
|
||||
gae_lambda: 0.95
|
||||
learning_rate: 0.0003
|
||||
clip_ratio: 0.2
|
||||
value_loss_scale: 0.5
|
||||
entropy_loss_scale: 0.01
|
||||
log_interval: 10
|
||||
clearml_project: RL-Framework
|
||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
torch
|
||||
gymnasium
|
||||
hydra-core
|
||||
omegaconf
|
||||
mujoco
|
||||
skrl[torch]
|
||||
clearml
|
||||
pytest
|
||||
0
src/core/__init__.py
Normal file
0
src/core/__init__.py
Normal file
59
src/core/env.py
Normal file
59
src/core/env.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
from typing import TypeVar, Generic, Any
|
||||
from gymnasium import spaces
|
||||
import torch
|
||||
import pathlib
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ActuatorConfig:
    """Maps a named joint to a motor with a gear ratio and control limits.

    Kept in the env config rather than the runner config: the actuator set
    defines what the robot can do, and therefore the action space — a
    task-level concept. This mirrors Isaac Lab's pattern of separating
    actuator configuration from the robot description file.

    Attributes:
        joint: name of the joint this motor drives.
        gear: gear ratio applied to the control signal.
        ctrl_range: (low, high) clamp for the control input.
    """
    joint: str = ""
    gear: float = 1.0
    ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BaseEnvConfig:
    """Base configuration shared by all environments.

    Subclassed per task (see CartPoleConfig) and populated from Hydra YAML.
    """
    max_steps: int = 1000  # episode length before time-limit truncation
    model_path: pathlib.Path | None = None  # robot description file (URDF/MJCF), if any
    # default_factory avoids the shared-mutable-default pitfall
    actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
|
||||
|
||||
class BaseEnv(abc.ABC, Generic[T]):
    """Abstract task definition.

    An env interprets raw simulator coordinates (qpos/qvel) into
    observations, rewards, and termination flags; it does not step
    physics — that is the runner's job (see BaseRunner). All hooks
    operate on batched tensors with a leading num_envs dimension.
    """

    def __init__(self, config: BaseEnvConfig):
        self.config = config

    @property
    @abc.abstractmethod
    def observation_space(self) -> spaces.Space:
        """Per-env observation space."""
        ...

    @property
    @abc.abstractmethod
    def action_space(self) -> spaces.Space:
        """Per-env action space."""
        ...

    @abc.abstractmethod
    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> Any:
        """Convert batched generalized coordinates into a task-specific state object."""
        ...

    @abc.abstractmethod
    def compute_observations(self, state: Any) -> torch.Tensor:
        """Return the batched observation tensor for *state*."""
        ...

    @abc.abstractmethod
    def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
        """Return the per-env reward tensor for *state* and *actions*."""
        ...

    @abc.abstractmethod
    def compute_terminations(self, state: Any) -> torch.Tensor:
        """Return a per-env boolean tensor: True where the episode ended."""
        ...

    def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
        """Time-limit truncation: True where an env reached max_steps."""
        return step_counts >= self.config.max_steps
|
||||
97
src/core/runner.py
Normal file
97
src/core/runner.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import dataclasses
|
||||
import abc
|
||||
from typing import Any, Generic, TypeVar
|
||||
from src.core.env import BaseEnv
|
||||
import torch
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@dataclasses.dataclass
class BaseRunnerConfig:
    """Common settings shared by all simulation runners.

    Attributes:
        num_envs: number of parallel simulated environments.
        device: torch device string for returned tensors.
    """
    num_envs: int = 1
    device: str = "cpu"
|
||||
|
||||
class BaseRunner(abc.ABC, Generic[T]):
    """Vectorized simulation driver wrapping a BaseEnv.

    Subclasses implement the _sim_* hooks for a concrete physics backend;
    this base class provides the gym-style reset/step API (including
    per-env auto-reset) that skrl consumes directly.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        self.env = env
        self.config = config

        self._sim_initialize(config)

        # Expose the env's spaces on the runner so skrl can treat it as the env.
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)

        # Per-env step counters used for time-limit truncation.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environments."""
        ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Device on which tensors are returned."""
        ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """Create backend-specific simulation objects."""
        ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply *actions*, advance physics, return batched (qpos, qvel)."""
        ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given envs, return their batched (qpos, qvel)."""
        ...

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release backend resources."""
        ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset every env and return (observations, info)."""
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    # Fix: annotation previously declared a 4-tuple, but five values are
    # returned (obs, rewards, terminated, truncated, info).
    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all envs once and auto-reset any that finished.

        Observations of finished envs are post-reset; their pre-reset
        (terminal) observations and ids are passed through
        info["final_observations"] / info["final_env_ids"].
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)

        info: dict[str, Any] = {}

        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)

        if done_ids.numel() > 0:
            # Preserve terminal observations before overwriting with reset obs.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()

            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0

            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)

        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Optional visualization hook; concrete runners may override."""
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Shut down the simulation backend."""
        self._sim_close()
|
||||
53
src/envs/cartpole.py
Normal file
53
src/envs/cartpole.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import dataclasses
|
||||
import torch
|
||||
from src.core.env import BaseEnv, BaseEnvConfig
|
||||
from gymnasium import spaces
|
||||
|
||||
@dataclasses.dataclass
class CartPoleState:
    """Batched cart-pole state; every field is a (num_envs,) tensor.

    Fix: fields were annotated `torch.float`, which is a dtype object,
    not a type — the values stored here are tensors.
    """
    cart_pos: torch.Tensor    # (num_envs,)
    cart_vel: torch.Tensor    # (num_envs,)
    pole_angle: torch.Tensor  # (num_envs,)
    pole_vel: torch.Tensor    # (num_envs,)
|
||||
|
||||
@dataclasses.dataclass
class CartPoleConfig(BaseEnvConfig):
    """CartPole task config. All values come from Hydra YAML."""
    angle_threshold: float = 0.418  # rad (~24 degrees); terminate beyond this tilt
    cart_limit: float = 2.4  # m; terminate when |cart_pos| exceeds this
    reward_alive: float = 1.0  # constant per-step survival bonus
    reward_pole_upright_scale: float = 1.0  # scales the cos(pole_angle) shaping term
    reward_action_penalty_scale: float = 0.01  # scales the sum(action^2) penalty
|
||||
|
||||
class CartPoleEnv(BaseEnv[CartPoleConfig]):
    """Classic cart-pole balancing task on top of the BaseEnv interface."""

    def __init__(self, config: CartPoleConfig):
        super().__init__(config)

    # Fix: both space properties were annotated `-> torch.Tensor` but
    # return gymnasium Box spaces.
    @property
    def observation_space(self) -> spaces.Box:
        """4-D observation: [cart_pos, cart_vel, pole_angle, pole_vel]."""
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))

    @property
    def action_space(self) -> spaces.Box:
        """1-D normalized force command in [-1, 1]."""
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
        """Slice batched coordinates into named fields.

        Column 0 is the prismatic cart joint, column 1 the revolute pole
        joint (declaration order in the URDF).
        """
        return CartPoleState(
            cart_pos=qpos[:, 0],
            cart_vel=qvel[:, 0],
            pole_angle=qpos[:, 1],
            pole_vel=qvel[:, 1],
        )

    def compute_observations(self, state: CartPoleState) -> torch.Tensor:
        """Stack the four state fields into a (num_envs, 4) tensor."""
        return torch.stack([state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1)

    def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Alive bonus + upright shaping (cos of angle) − action magnitude penalty."""
        upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
        action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
        return self.config.reward_alive + upright - action_penalty

    def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
        """Terminate when the pole tips past the threshold or the cart leaves the track."""
        pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
        cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
        return pole_fallen | cart_out_of_bounds
|
||||
0
src/models/__init__.py
Normal file
0
src/models/__init__.py
Normal file
48
src/models/mlp.py
Normal file
48
src/models/mlp.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gymnasium import spaces
|
||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||
|
||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
    """Shared-backbone actor-critic MLP for skrl's PPO.

    A single trunk (`net`) feeds two heads: a Gaussian policy head
    (mean layer + learnable log-std parameter) and a deterministic
    value head. The trunk output is cached between the "policy" and
    "value" calls of the same forward pass to avoid recomputation.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: torch.device,
        hidden_sizes: tuple[int, ...] = (32, 32),
        clip_actions: bool = False,
        clip_log_std: bool = True,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        initial_log_std: float = 0.0,
    ):
        # skrl multiple-inheritance pattern: each base initialized explicitly.
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
        DeterministicMixin.__init__(self, clip_actions)

        # Shared trunk: Linear + ELU per entry in hidden_sizes.
        layers = []
        in_dim: int = self.num_observations
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_size))
            layers.append(nn.ELU())
            in_dim = hidden_size
        self.net: nn.Sequential = nn.Sequential(*layers)

        # Policy head
        self.mean_layer = nn.Linear(in_dim, self.num_actions)
        self.log_std_parameter: nn.Parameter = nn.Parameter(torch.full((self.num_actions,), initial_log_std))

        # Value head
        self.value_layer = nn.Linear(in_dim, 1)
        # Trunk activations cached by the "policy" pass, consumed by "value".
        self._shared_output: torch.Tensor | None = None

    def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
        """Dispatch to the mixin matching *role* ("policy" or "value").

        NOTE(review): any other role falls through and returns None —
        assumed unreachable under skrl's PPO; confirm if roles change.
        """
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    def compute(
        self, inputs: dict[str, torch.Tensor], role: str = ""
    ) -> tuple[torch.Tensor, ...]:
        """Forward pass for one head, sharing the trunk between the two."""
        if role == "policy":
            self._shared_output = self.net(inputs["states"])
            return self.mean_layer(self._shared_output), self.log_std_parameter, {}
        elif role == "value":
            # Reuse the policy pass's trunk output when available.
            shared_output = (
                self._shared_output
                if self._shared_output is not None
                else self.net(inputs["states"])
            )
            self._shared_output = None  # consume the cache
            return self.value_layer(shared_output), {}
|
||||
155
src/runners/mujoco.py
Normal file
155
src/runners/mujoco.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import dataclasses
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
|
||||
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """MuJoCo-specific runner settings (extends BaseRunnerConfig)."""
    num_envs: int = 16
    device: str = "cpu"
    dt: float = 0.02  # physics timestep in seconds (assigned to MjModel.opt.timestep)
    substeps: int = 2  # mj_step calls per runner step()
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """CPU MuJoCo backend: one MjData per env, stepped sequentially."""

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        return torch.device(self.config.device)

    @staticmethod
    def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
        """Load a URDF (or MJCF) file and programmatically inject actuators.

        Two-step approach required because MuJoCo's URDF parser ignores
        <actuator> in the <mujoco> extension block:
        1. Load the URDF → MuJoCo converts it to internal MJCF
        2. Export the MJCF XML, add <actuator> elements, reload

        This keeps the URDF clean and standard — actuator config lives in
        the env config (Isaac Lab pattern), not in the robot file.
        """
        import os

        # Step 1: Load URDF/MJCF as-is (no actuators)
        model_raw = mujoco.MjModel.from_xml_path(model_path)

        if not actuators:
            return model_raw

        # Step 2: Export internal MJCF representation.
        # Fix: use mkstemp instead of the deprecated, race-prone mktemp —
        # the file is created atomically with safe permissions.
        fd, tmp_mjcf = tempfile.mkstemp(suffix=".xml")
        os.close(fd)  # mj_saveLastXML writes by path; the open fd is not needed
        try:
            mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
            with open(tmp_mjcf) as f:
                mjcf_str = f.read()
        finally:
            os.unlink(tmp_mjcf)

        # Step 3: Inject actuators into the MJCF XML
        root = ET.fromstring(mjcf_str)
        act_elem = ET.SubElement(root, "actuator")
        for act in actuators:
            ET.SubElement(act_elem, "motor", attrib={
                "name": f"{act.joint}_motor",
                "joint": act.joint,
                "gear": str(act.gear),
                "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
            })

        # Step 4: Reload from modified MJCF
        modified_xml = ET.tostring(root, encoding="unicode")
        return mujoco.MjModel.from_xml_string(modified_xml)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Build the shared MjModel (with injected actuators) and per-env MjData."""
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified in the environment config")

        actuators = self.env.config.actuators
        self._model = self._load_model_with_actuators(str(model_path), actuators)
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]

        self._nq = self._model.nq  # generalized position dimension
        self._nv = self._model.nv  # generalized velocity dimension

        # Visualization handles created lazily by render(); initialized here
        # so _sim_close/render can use simple None checks instead of hasattr.
        self._viewer = None
        self._offscreen_renderer = None

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply controls and advance each env by `substeps` physics steps."""
        actions_np: np.ndarray = actions.cpu().numpy()

        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)

        for i, data in enumerate(self._data):
            data.ctrl[:] = actions_np[i]
            for _ in range(self.config.substeps):
                mujoco.mj_step(self._model, data)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given envs to the model's default pose plus small noise."""
        ids = env_ids.cpu().numpy()
        n = len(ids)

        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)

        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)

            # Add small random perturbation so the pole doesn't start perfectly upright
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_close(self) -> None:
        """Tear down viewers/renderers and drop per-env MjData."""
        # getattr guards keep this safe even if _sim_initialize failed early.
        if getattr(self, "_viewer", None) is not None:
            self._viewer.close()
            self._viewer = None

        if getattr(self, "_offscreen_renderer", None) is not None:
            self._offscreen_renderer.close()
            self._offscreen_renderer = None

        self._data.clear()

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render env *env_idx*: interactive viewer ("human") or pixels ("rgb_array")."""
        if mode == "human":
            if self._viewer is None:
                self._viewer = mujoco.viewer.launch_passive(
                    self._model, self._data[env_idx]
                )
            # Update visual geometry from current physics state
            mujoco.mj_forward(self._model, self._data[env_idx])
            self._viewer.sync()
            return None
        elif mode == "rgb_array":
            # Cache the offscreen renderer to avoid create/destroy overhead
            if self._offscreen_renderer is None:
                self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
            self._offscreen_renderer.update_scene(self._data[env_idx])
            pixels = self._offscreen_renderer.render().copy()  # copy since buffer is reused
            return torch.from_numpy(pixels)
        return None  # unknown mode: no-op (matches original fall-through)
|
||||
243
src/training/trainer.py
Normal file
243
src/training/trainer.py
Normal file
@@ -0,0 +1,243 @@
|
||||
import dataclasses
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from src.core.runner import BaseRunner
|
||||
from clearml import Task, Logger
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from src.models.mlp import SharedMLP
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainerConfig:
|
||||
rollout_steps: int = 2048
|
||||
learning_epochs: int = 8
|
||||
mini_batches: int = 4
|
||||
discount_factor: float = 0.99
|
||||
gae_lambda: float = 0.95
|
||||
learning_rate: float = 3e-4
|
||||
clip_ratio: float = 0.2
|
||||
value_loss_scale: float = 0.5
|
||||
entropy_loss_scale: float = 0.01
|
||||
|
||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
|
||||
# Video recording
|
||||
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
|
||||
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
|
||||
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
|
||||
|
||||
clearml_project: str | None = None
|
||||
clearml_task: str | None = None
|
||||
|
||||
|
||||
class VideoRecordingTrainer(SequentialTrainer):
    """Subclass of skrl's SequentialTrainer that records videos periodically.

    The loop mirrors SequentialTrainer's single-agent training, with an
    extra hook that every `record_video_every` timesteps rolls out the
    current policy, encodes frames to mp4, and uploads them to ClearML.
    """

    def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
        super().__init__(env=env, agents=agents, cfg=cfg)
        self._trainer_config = trainer_config
        # Scratch directory where encoded videos land before upload.
        self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))

    def single_agent_train(self) -> None:
        """Override to add periodic video recording."""
        assert self.num_simultaneous_agents == 1
        assert self.env.num_agents == 1

        states, infos = self.env.reset()

        for timestep in tqdm.tqdm(
            range(self.initial_timestep, self.timesteps),
            disable=self.disable_progressbar,
            file=sys.stdout,
        ):
            # Pre-interaction
            self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

            # Acting and stepping need no gradients; PPO learns from the buffer.
            with torch.no_grad():
                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
                next_states, rewards, terminated, truncated, infos = self.env.step(actions)

            if not self.headless:
                self.env.render()

            self.agents.record_transition(
                states=states,
                actions=actions,
                rewards=rewards,
                next_states=next_states,
                terminated=terminated,
                truncated=truncated,
                infos=infos,
                timestep=timestep,
                timesteps=self.timesteps,
            )

            # Forward scalar entries from the env's info dict to the logger.
            if self.environment_info in infos:
                for k, v in infos[self.environment_info].items():
                    if isinstance(v, torch.Tensor) and v.numel() == 1:
                        self.agents.track_data(f"Info / {k}", v.item())

            self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

            # Reset environments (vectorized runners auto-reset internally)
            if self.env.num_envs > 1:
                states = next_states
            else:
                if terminated.any() or truncated.any():
                    with torch.no_grad():
                        states, infos = self.env.reset()
                else:
                    states = next_states

            # Record video at intervals
            cfg = self._trainer_config
            if (
                cfg
                and cfg.record_video_every > 0
                and (timestep + 1) % cfg.record_video_every == 0
            ):
                self._record_video(timestep + 1)

    def _get_video_fps(self) -> int:
        """Derive video fps from the simulation rate, or use configured value."""
        cfg = self._trainer_config
        if cfg.record_video_fps > 0:
            return cfg.record_video_fps
        # Auto-derive from runner's simulation parameters
        runner = self.env
        dt = getattr(runner.config, "dt", 0.02)
        substeps = getattr(runner.config, "substeps", 1)
        return max(1, int(round(1.0 / (dt * substeps))))

    def _record_video(self, timestep: int) -> None:
        """Record evaluation episodes and upload to ClearML."""
        # imageio is optional: silently skip recording when unavailable.
        try:
            import imageio.v3 as iio
        except ImportError:
            try:
                import imageio as iio
            except ImportError:
                return

        cfg = self._trainer_config
        fps = self._get_video_fps()
        min_frames = int(cfg.record_video_min_seconds * fps)
        max_frames = min_frames * 3  # hard cap to prevent runaway recording
        frames: list[np.ndarray] = []

        # Roll out full episodes until enough frames are captured.
        while len(frames) < min_frames and len(frames) < max_frames:
            obs, _ = self.env.reset()
            done = False
            steps = 0
            max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
            while not done and steps < max_episode_steps:
                with torch.no_grad():
                    action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
                obs, _, terminated, truncated, _ = self.env.step(action)
                frame = self.env.render(mode="rgb_array")
                if frame is not None:
                    frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
                done = (terminated | truncated).any().item()
                steps += 1
                if len(frames) >= max_frames:
                    break

        if frames:
            video_path = str(self._video_dir / f"step_{timestep}.mp4")
            iio.imwrite(video_path, frames, fps=fps)

            logger = Logger.current_logger()
            if logger:
                logger.report_media(
                    title="Training Video",
                    series=f"step_{timestep}",
                    local_path=video_path,
                    iteration=timestep,
                )

        # Reset back to training state after recording
        self.env.reset()
|
||||
|
||||
class Trainer:
    """Wires runner + PPO agent + ClearML logging and launches training."""

    def __init__(self, runner: BaseRunner, config: TrainerConfig):
        self.runner = runner
        self.config = config

        # Set up experiment tracking first, then the learning agent.
        self._init_clearml()
        self._init_agent()

    def _init_clearml(self) -> None:
        """Create a ClearML task when both project and task name are configured."""
        if self.config.clearml_project and self.config.clearml_task:
            self.clearml_task = Task.init(
                project_name=self.config.clearml_project,
                task_name=self.config.clearml_task,
            )
        else:
            self.clearml_task = None

    def _init_agent(self) -> None:
        """Build rollout memory, the shared actor-critic model, and the PPO agent."""
        device: torch.device = self.runner.device
        obs_space: spaces.Space = self.runner.observation_space
        act_space: spaces.Space = self.runner.action_space
        num_envs: int = self.runner.num_envs

        # Buffer sized to exactly one on-policy rollout.
        self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)

        self.model: SharedMLP = SharedMLP(
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            hidden_sizes=self.config.hidden_sizes,
        )

        # The same module serves as both policy and value function.
        models = {
            "policy": self.model,
            "value": self.model,
        }

        # Map our config field names onto skrl's PPO config keys.
        agent_cfg = PPO_DEFAULT_CONFIG.copy()
        agent_cfg.update({
            "rollouts": self.config.rollout_steps,
            "learning_epochs": self.config.learning_epochs,
            "mini_batches": self.config.mini_batches,
            "discount_factor": self.config.discount_factor,
            "lambda": self.config.gae_lambda,
            "learning_rate": self.config.learning_rate,
            "ratio_clip": self.config.clip_ratio,
            "value_loss_scale": self.config.value_loss_scale,
            "entropy_loss_scale": self.config.entropy_loss_scale,
        })

        self.agent: PPO = PPO(
            models=models,
            memory=self.memory,
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            cfg=agent_cfg,
        )

    def train(self) -> None:
        """Run the full training loop (blocking)."""
        trainer = VideoRecordingTrainer(
            env=self.runner,
            agents=self.agent,
            cfg={"timesteps": self.config.total_timesteps},
            trainer_config=self.config,
        )
        trainer.train()

    def close(self) -> None:
        """Release simulator resources and finish the ClearML task."""
        self.runner.close()
        if self.clearml_task:
            self.clearml_task.close()
|
||||
47
train.py
Normal file
47
train.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import hydra
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
from src.core.env import ActuatorConfig
|
||||
|
||||
|
||||
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
    """Turn the Hydra `env` node into a typed CartPoleConfig.

    YAML yields plain dicts/lists; actuator entries are promoted to
    ActuatorConfig instances and their ctrl_range lists become tuples.
    """
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)
    if "actuators" in env_dict:
        typed_actuators = []
        for entry in env_dict["actuators"]:
            if "ctrl_range" in entry:
                entry["ctrl_range"] = tuple(entry["ctrl_range"])
            typed_actuators.append(ActuatorConfig(**entry))
        env_dict["actuators"] = typed_actuators
    return CartPoleConfig(**env_dict)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: build typed configs from Hydra, wire env/runner/trainer, train."""
    env_config = _build_env_config(cfg)
    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))

    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    # Build ClearML task name dynamically from Hydra config group choices
    if not training_dict.get("clearml_task"):
        choices = HydraConfig.get().runtime.choices
        env_name = choices.get("env", "env")
        runner_name = choices.get("runner", "runner")
        training_name = choices.get("training", "algo")
        training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
    trainer_config = TrainerConfig(**training_dict)

    env = CartPoleEnv(env_config)
    runner = MuJoCoRunner(env=env, config=runner_config)
    trainer = Trainer(runner=runner, config=trainer_config)

    try:
        trainer.train()
    finally:
        # Always release sim resources and close the ClearML task.
        trainer.close()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user