initial commit

This commit is contained in:
2026-03-06 22:19:44 +01:00
commit c8f28ffbcc
17 changed files with 811 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
outputs/
.vscode/
runs/

1
.python-version Normal file
View File

@@ -0,0 +1 @@
RL-Framework-7914bb

View File

@@ -0,0 +1,64 @@
<?xml version="1.0" encoding="utf-8"?>
<robot name="cartpole">
<!-- World link (fixed base) -->
<link name="world"/>
<!-- Cart (slides along x-axis) -->
<link name="cart">
<inertial>
<mass value="1.0"/>
<inertia ixx="0.001" ixy="0" ixz="0" iyy="0.001" iyz="0" izz="0.001"/>
</inertial>
<visual>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</visual>
<collision>
<geometry>
<box size="0.3 0.2 0.1"/>
</geometry>
</collision>
</link>
<!-- Cart slides along x-axis -->
<joint name="cart_joint" type="prismatic">
<parent link="world"/>
<child link="cart"/>
<axis xyz="1 0 0"/>
<limit lower="-2.4" upper="2.4" effort="100" velocity="10"/>
</joint>
<!-- Pole (rotates around y-axis, attached on top of cart) -->
<link name="pole">
<inertial>
<origin xyz="0 0 0.3"/>
<mass value="0.1"/>
<inertia ixx="0.003" ixy="0" ixz="0" iyy="0.003" iyz="0" izz="0.0001"/>
</inertial>
<visual>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</visual>
<collision>
<origin xyz="0 0 0.3"/>
<geometry>
<cylinder radius="0.02" length="0.6"/>
</geometry>
</collision>
</link>
<!-- Pole rotates freely (no motor) -->
<joint name="pole_joint" type="revolute">
<parent link="cart"/>
<child link="pole"/>
<origin xyz="0 0 0.05"/>
<axis xyz="0 1 0"/>
<limit lower="-6.28" upper="6.28" effort="0" velocity="100"/>
<dynamics damping="0.0" friction="0.0"/>
</joint>
</robot>

5
configs/config.yaml Normal file
View File

@@ -0,0 +1,5 @@
defaults:
- env: cartpole
- runner: mujoco
- training: ppo
- _self_

11
configs/env/cartpole.yaml vendored Normal file
View File

@@ -0,0 +1,11 @@
max_steps: 500
angle_threshold: 0.418
cart_limit: 2.4
reward_alive: 1.0
reward_pole_upright_scale: 1.0
reward_action_penalty_scale: 0.01
model_path: assets/cartpole/cartpole.urdf
actuators:
- joint: cart_joint
gear: 10.0
ctrl_range: [-1.0, 1.0]

View File

@@ -0,0 +1,4 @@
num_envs: 16
device: cpu
dt: 0.02
substeps: 2

13
configs/training/ppo.yaml Normal file
View File

@@ -0,0 +1,13 @@
hidden_sizes: [128, 128]
total_timesteps: 1000000
rollout_steps: 1024
learning_epochs: 4
mini_batches: 4
discount_factor: 0.99
gae_lambda: 0.95
learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
entropy_loss_scale: 0.01
log_interval: 10
clearml_project: RL-Framework

8
requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
torch
gymnasium
hydra-core
omegaconf
mujoco
skrl[torch]
clearml
pytest

0
src/core/__init__.py Normal file
View File

59
src/core/env.py Normal file
View File

@@ -0,0 +1,59 @@
import abc
import dataclasses
from typing import TypeVar, Generic, Any
from gymnasium import spaces
import torch
import pathlib
T = TypeVar("T")
@dataclasses.dataclass
class ActuatorConfig:
    """Maps a single joint to a motor.

    Lives in the env config rather than the runner config: actuators define
    what the robot can do and therefore the action space, which is a
    task-level concept. Mirrors Isaac Lab's separation of actuator config
    from the robot description file.
    """
    # Name of the joint this motor drives (must match the model file).
    joint: str = ""
    # Gear (force/torque) ratio applied to the control signal.
    gear: float = 1.0
    # (lower, upper) bounds on the raw control input.
    ctrl_range: tuple[float, float] = (-1.0, 1.0)
@dataclasses.dataclass
class BaseEnvConfig:
    """Base configuration shared by all task environments."""
    # Episode length in steps before time-limit truncation.
    max_steps: int = 1000
    # Path to the robot model file (URDF/MJCF) consumed by the runner.
    model_path: pathlib.Path | None = None
    # Motors to attach to the model's joints (empty = no actuation injected).
    actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
class BaseEnv(abc.ABC, Generic[T]):
    """Abstract task definition: spaces, state construction, rewards and
    terminations. Simulation stepping is a runner concern, not handled here."""

    def __init__(self, config: BaseEnvConfig):
        self.config = config

    @property
    @abc.abstractmethod
    def observation_space(self) -> spaces.Space:
        """Gymnasium space describing a single env's observation."""
        ...

    @property
    @abc.abstractmethod
    def action_space(self) -> spaces.Space:
        """Gymnasium space describing a single env's action."""
        ...

    @abc.abstractmethod
    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> Any:
        """Convert batched generalized positions/velocities into a task state object."""
        ...

    @abc.abstractmethod
    def compute_observations(self, state: Any) -> torch.Tensor:
        """Return the batched observation tensor derived from `state`."""
        ...

    @abc.abstractmethod
    def compute_rewards(self, state: Any, actions: torch.Tensor) -> torch.Tensor:
        """Return the per-env reward tensor for `state` and the applied `actions`."""
        ...

    @abc.abstractmethod
    def compute_terminations(self, state: Any) -> torch.Tensor:
        """Return a per-env boolean tensor marking task-level terminations."""
        ...

    def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
        """Per-env time-limit truncation: true once an env reaches max_steps."""
        return step_counts >= self.config.max_steps

97
src/core/runner.py Normal file
View File

@@ -0,0 +1,97 @@
import dataclasses
import abc
from typing import Any, Generic, TypeVar
from src.core.env import BaseEnv
import torch
T = TypeVar("T")
@dataclasses.dataclass
class BaseRunnerConfig:
    """Base configuration for simulation runners."""
    # Number of parallel environment instances.
    num_envs: int = 1
    # Torch device string on which runner tensors are allocated.
    device: str = "cpu"
class BaseRunner(abc.ABC, Generic[T]):
    """Vectorized simulation wrapper exposing a gymnasium-style reset/step API.

    Subclasses implement the `_sim_*` hooks against a concrete simulator;
    the task logic (observations, rewards, terminations) is delegated to
    the wrapped BaseEnv.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        self.env = env
        self.config = config
        # Initialize the simulator backend before anything queries it.
        self._sim_initialize(config)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)
        # Per-env step counter used for time-limit truncation.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environment instances."""
        ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Device on which observation/reward tensors are returned."""
        ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """Create simulator resources; called once from __init__."""
        ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance every env one control step; return batched (qpos, qvel)."""
        ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given envs; return their batched (qpos, qvel)."""
        ...

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release simulator resources."""
        ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset all envs and return (observations, info)."""
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    # NOTE: the return annotation is a 5-tuple — the previous 4-element
    # annotation did not match the five values actually returned.
    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all envs, compute task outputs, and auto-reset finished envs.

        Returns (obs, rewards, terminated, truncated, info). Pre-reset
        observations of finished envs are exposed in the info dict so the
        agent can bootstrap from the true final state.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)
        info: dict[str, Any] = {}
        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)
        if done_ids.numel() > 0:
            # Preserve the true final observations before overwriting with reset obs.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()
            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0
            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)
        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Optional visualization hook; concrete runners may override."""
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Shut down the underlying simulator."""
        self._sim_close()

53
src/envs/cartpole.py Normal file
View File

@@ -0,0 +1,53 @@
import dataclasses
import torch
from src.core.env import BaseEnv, BaseEnvConfig
from gymnasium import spaces
@dataclasses.dataclass
class CartPoleState:
    """Per-env cart-pole state; each field is a tensor of shape (num_envs,).

    Fixed defect: fields were annotated ``torch.float`` (a dtype object),
    not a type — they hold ``torch.Tensor`` values.
    """
    cart_pos: torch.Tensor    # cart position along the slide axis
    cart_vel: torch.Tensor    # cart linear velocity
    pole_angle: torch.Tensor  # pole hinge angle (rad); 0 is upright per the reward/termination logic
    pole_vel: torch.Tensor    # pole angular velocity
@dataclasses.dataclass
class CartPoleConfig(BaseEnvConfig):
    """CartPole task parameters; values are populated from the Hydra YAML."""
    # Terminate when |pole_angle| exceeds this (radians, ~24 degrees).
    angle_threshold: float = 0.418
    # Terminate when |cart_pos| exceeds this.
    cart_limit: float = 2.4
    # Constant reward granted every live step.
    reward_alive: float = 1.0
    # Scale on the cos(pole_angle) uprightness bonus.
    reward_pole_upright_scale: float = 1.0
    # Scale on the squared-action penalty.
    reward_action_penalty_scale: float = 0.01
class CartPoleEnv(BaseEnv[CartPoleConfig]):
    """Classic cart-pole balancing task on top of the BaseEnv interface.

    Fixed defect: the space properties were annotated ``-> torch.Tensor``
    although they return gymnasium ``spaces.Box`` instances.
    """

    def __init__(self, config: CartPoleConfig):
        super().__init__(config)

    @property
    def observation_space(self) -> spaces.Space:
        # Observation layout: [cart_pos, cart_vel, pole_angle, pole_vel]
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(4,))

    @property
    def action_space(self) -> spaces.Space:
        # Single normalized control command for the cart motor.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> CartPoleState:
        """Slice batched sim tensors into named fields.

        Assumes DOF 0 is the cart slide joint and DOF 1 the pole hinge,
        matching the declaration order in the cartpole URDF.
        """
        return CartPoleState(
            cart_pos=qpos[:, 0],
            cart_vel=qvel[:, 0],
            pole_angle=qpos[:, 1],
            pole_vel=qvel[:, 1],
        )

    def compute_observations(self, state: CartPoleState) -> torch.Tensor:
        """Stack state fields into a (num_envs, 4) observation tensor."""
        return torch.stack(
            [state.cart_pos, state.cart_vel, state.pole_angle, state.pole_vel], dim=-1
        )

    def compute_rewards(self, state: CartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Alive bonus + uprightness (cos of angle) - squared-action penalty."""
        upright = self.config.reward_pole_upright_scale * torch.cos(state.pole_angle)
        action_penalty = self.config.reward_action_penalty_scale * torch.sum(actions**2, dim=-1)
        return self.config.reward_alive + upright - action_penalty

    def compute_terminations(self, state: CartPoleState) -> torch.Tensor:
        """Terminate when the pole falls past the threshold or the cart leaves the track."""
        pole_fallen = torch.abs(state.pole_angle) > self.config.angle_threshold
        cart_out_of_bounds = torch.abs(state.cart_pos) > self.config.cart_limit
        return pole_fallen | cart_out_of_bounds

0
src/models/__init__.py Normal file
View File

48
src/models/mlp.py Normal file
View File

@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from gymnasium import spaces
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
    """Shared-backbone actor-critic: one MLP trunk feeding a Gaussian policy
    head and a deterministic value head (skrl shared-model pattern).

    Fixed defect: ``act``/``compute`` silently returned ``None`` for an
    unrecognized role, which surfaces later as an opaque TypeError; they now
    raise a descriptive ValueError.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: torch.device,
        hidden_sizes: tuple[int, ...] = (32, 32),
        clip_actions: bool = False,
        clip_log_std: bool = True,
        min_log_std: float = -20.0,
        max_log_std: float = 2.0,
        initial_log_std: float = 0.0,
    ):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
        DeterministicMixin.__init__(self, clip_actions)
        # Shared trunk: Linear + ELU per hidden size.
        layers = []
        in_dim: int = self.num_observations
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_size))
            layers.append(nn.ELU())
            in_dim = hidden_size
        self.net: nn.Sequential = nn.Sequential(*layers)
        # Policy head
        self.mean_layer = nn.Linear(in_dim, self.num_actions)
        self.log_std_parameter: nn.Parameter = nn.Parameter(
            torch.full((self.num_actions,), initial_log_std)
        )
        # Value head
        self.value_layer = nn.Linear(in_dim, 1)
        # Trunk output cached between the policy pass and the value pass so
        # the backbone only runs once when both roles are queried in sequence.
        self._shared_output: torch.Tensor | None = None

    def act(self, inputs: dict[str, torch.Tensor], role: str = "") -> tuple[torch.Tensor, ...]:
        """Dispatch to the Gaussian (policy) or deterministic (value) mixin."""
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)
        raise ValueError(f"Unsupported role: {role!r} (expected 'policy' or 'value')")

    def compute(
        self, inputs: dict[str, torch.Tensor], role: str = ""
    ) -> tuple[torch.Tensor, ...]:
        """Forward pass for the requested role, sharing the trunk output."""
        if role == "policy":
            self._shared_output = self.net(inputs["states"])
            return self.mean_layer(self._shared_output), self.log_std_parameter, {}
        elif role == "value":
            # Reuse the cached trunk output when the policy pass just ran.
            shared_output = (
                self._shared_output
                if self._shared_output is not None
                else self.net(inputs["states"])
            )
            self._shared_output = None
            return self.value_layer(shared_output), {}
        raise ValueError(f"Unsupported role: {role!r} (expected 'policy' or 'value')")

155
src/runners/mujoco.py Normal file
View File

@@ -0,0 +1,155 @@
import dataclasses
import tempfile
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
import mujoco.viewer
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """Configuration for the CPU MuJoCo runner."""
    # Number of independent MjData instances stepped each frame.
    num_envs: int = 16
    # Device tensors are returned on (physics itself runs on CPU).
    device: str = "cpu"
    # Physics timestep (seconds) written into the MuJoCo model.
    dt: float = 0.02
    # mj_step calls performed per policy action.
    substeps: int = 2
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """BaseRunner backed by classic MuJoCo: one MjData per env, stepped serially on CPU."""

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        return torch.device(self.config.device)

    @staticmethod
    def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
        """Load a URDF (or MJCF) file and programmatically inject actuators.

        Two-step approach required because MuJoCo's URDF parser ignores
        <actuator> in the <mujoco> extension block:
        1. Load the URDF → MuJoCo converts it to internal MJCF
        2. Export the MJCF XML, add <actuator> elements, reload
        This keeps the URDF clean and standard — actuator config lives in
        the env config (Isaac Lab pattern), not in the robot file.
        """
        import os

        # Step 1: Load URDF/MJCF as-is (no actuators)
        model_raw = mujoco.MjModel.from_xml_path(model_path)
        if not actuators:
            return model_raw
        # Step 2: Export internal MJCF representation.
        # mkstemp creates the file atomically (unlike the deprecated,
        # race-prone mktemp); close our fd since mj_saveLastXML writes by path.
        fd, tmp_mjcf = tempfile.mkstemp(suffix=".xml")
        os.close(fd)
        try:
            mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
            with open(tmp_mjcf) as f:
                mjcf_str = f.read()
        finally:
            os.unlink(tmp_mjcf)
        # Step 3: Inject actuators into the MJCF XML
        root = ET.fromstring(mjcf_str)
        act_elem = ET.SubElement(root, "actuator")
        for act in actuators:
            ET.SubElement(act_elem, "motor", attrib={
                "name": f"{act.joint}_motor",
                "joint": act.joint,
                "gear": str(act.gear),
                "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
            })
        # Step 4: Reload from modified MJCF
        modified_xml = ET.tostring(root, encoding="unicode")
        return mujoco.MjModel.from_xml_string(modified_xml)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Compile the model (with injected actuators) and allocate per-env MjData."""
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified in the environment config")
        actuators = self.env.config.actuators
        self._model = self._load_model_with_actuators(str(model_path), actuators)
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
        self._nq = self._model.nq
        self._nv = self._model.nv

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply actions and advance each env by `substeps` physics steps."""
        actions_np: np.ndarray = actions.cpu().numpy()
        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
        for i, data in enumerate(self._data):
            data.ctrl[:] = actions_np[i]
            for _ in range(self.config.substeps):
                mujoco.mj_step(self._model, data)
            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel
        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected envs to the model default state plus small noise."""
        ids = env_ids.cpu().numpy()
        n = len(ids)
        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)
            # Add small random perturbation so the pole doesn't start perfectly upright
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel
        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_close(self) -> None:
        """Close viewer/renderer if created, then drop the MjData instances."""
        if hasattr(self, "_viewer") and self._viewer is not None:
            self._viewer.close()
            self._viewer = None
        if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
            self._offscreen_renderer.close()
            self._offscreen_renderer = None
        self._data.clear()

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render one env: interactive viewer ("human") or RGB tensor ("rgb_array").

        Returns None for "human" mode and for unrecognized modes.
        """
        if mode == "human":
            if not hasattr(self, "_viewer") or self._viewer is None:
                self._viewer = mujoco.viewer.launch_passive(
                    self._model, self._data[env_idx]
                )
            # Update visual geometry from current physics state
            mujoco.mj_forward(self._model, self._data[env_idx])
            self._viewer.sync()
            return None
        elif mode == "rgb_array":
            # Cache the offscreen renderer to avoid create/destroy overhead
            if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
                self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
            self._offscreen_renderer.update_scene(self._data[env_idx])
            pixels = self._offscreen_renderer.render().copy()  # copy since buffer is reused
            return torch.from_numpy(pixels)

243
src/training/trainer.py Normal file
View File

@@ -0,0 +1,243 @@
import dataclasses
import sys
import tempfile
from pathlib import Path
import numpy as np
import tqdm
from src.core.runner import BaseRunner
from clearml import Task, Logger
import torch
from gymnasium import spaces
from skrl.memories.torch import RandomMemory
from src.models.mlp import SharedMLP
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.trainers.torch import SequentialTrainer
@dataclasses.dataclass
class TrainerConfig:
rollout_steps: int = 2048
learning_epochs: int = 8
mini_batches: int = 4
discount_factor: float = 0.99
gae_lambda: float = 0.95
learning_rate: float = 3e-4
clip_ratio: float = 0.2
value_loss_scale: float = 0.5
entropy_loss_scale: float = 0.01
hidden_sizes: tuple[int, ...] = (64, 64)
total_timesteps: int = 1_000_000
log_interval: int = 10
# Video recording
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
clearml_project: str | None = None
clearml_task: str | None = None
class VideoRecordingTrainer(SequentialTrainer):
    """Subclass of skrl's SequentialTrainer that records videos periodically.

    The training loop mirrors skrl's single_agent_train and additionally
    triggers an evaluation-video recording every
    ``trainer_config.record_video_every`` timesteps.
    """

    def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
        super().__init__(env=env, agents=agents, cfg=cfg)
        self._trainer_config = trainer_config
        # Videos are written to a throwaway temp directory; ClearML uploads copies.
        self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))

    def single_agent_train(self) -> None:
        """Override to add periodic video recording."""
        assert self.num_simultaneous_agents == 1
        assert self.env.num_agents == 1
        states, infos = self.env.reset()
        for timestep in tqdm.tqdm(
            range(self.initial_timestep, self.timesteps),
            disable=self.disable_progressbar,
            file=sys.stdout,
        ):
            # Pre-interaction
            self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
            with torch.no_grad():
                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
            next_states, rewards, terminated, truncated, infos = self.env.step(actions)
            if not self.headless:
                self.env.render()
            # Store the transition in the agent's memory for later PPO updates.
            self.agents.record_transition(
                states=states,
                actions=actions,
                rewards=rewards,
                next_states=next_states,
                terminated=terminated,
                truncated=truncated,
                infos=infos,
                timestep=timestep,
                timesteps=self.timesteps,
            )
            # Forward any scalar environment-info entries into the agent's tracking.
            if self.environment_info in infos:
                for k, v in infos[self.environment_info].items():
                    if isinstance(v, torch.Tensor) and v.numel() == 1:
                        self.agents.track_data(f"Info / {k}", v.item())
            # Post-interaction runs the learning update when a rollout completes.
            self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
            # Reset environments
            # Vectorized runners auto-reset internally, so just carry next_states.
            if self.env.num_envs > 1:
                states = next_states
            else:
                if terminated.any() or truncated.any():
                    with torch.no_grad():
                        states, infos = self.env.reset()
                else:
                    states = next_states
            # Record video at intervals
            cfg = self._trainer_config
            if (
                cfg
                and cfg.record_video_every > 0
                and (timestep + 1) % cfg.record_video_every == 0
            ):
                self._record_video(timestep + 1)

    def _get_video_fps(self) -> int:
        """Derive video fps from the simulation rate, or use configured value."""
        cfg = self._trainer_config
        if cfg.record_video_fps > 0:
            return cfg.record_video_fps
        # Auto-derive from runner's simulation parameters
        # (one rendered frame per policy step = 1 / (dt * substeps) fps).
        runner = self.env
        dt = getattr(runner.config, "dt", 0.02)
        substeps = getattr(runner.config, "substeps", 1)
        return max(1, int(round(1.0 / (dt * substeps))))

    def _record_video(self, timestep: int) -> None:
        """Record evaluation episodes and upload to ClearML.

        Runs full episodes with the current policy until at least
        ``record_video_min_seconds`` worth of frames are collected (hard
        capped at 3x that amount), then writes an mp4 and reports it.
        Silently no-ops when imageio is not installed.
        """
        try:
            import imageio.v3 as iio
        except ImportError:
            try:
                import imageio as iio
            except ImportError:
                return
        cfg = self._trainer_config
        fps = self._get_video_fps()
        min_frames = int(cfg.record_video_min_seconds * fps)
        max_frames = min_frames * 3  # hard cap to prevent runaway recording
        frames: list[np.ndarray] = []
        # NOTE(review): the second condition is redundant (min_frames <= max_frames);
        # kept for clarity of intent.
        while len(frames) < min_frames and len(frames) < max_frames:
            obs, _ = self.env.reset()
            done = False
            steps = 0
            # self.env is the runner; self.env.env is the task (BaseEnv) holding max_steps.
            max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
            while not done and steps < max_episode_steps:
                with torch.no_grad():
                    action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
                obs, _, terminated, truncated, _ = self.env.step(action)
                frame = self.env.render(mode="rgb_array")
                if frame is not None:
                    frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
                done = (terminated | truncated).any().item()
                steps += 1
                if len(frames) >= max_frames:
                    break
        if frames:
            video_path = str(self._video_dir / f"step_{timestep}.mp4")
            # NOTE(review): assumes the installed imageio accepts fps= for mp4
            # writing via imwrite — confirm against the pinned imageio version.
            iio.imwrite(video_path, frames, fps=fps)
            logger = Logger.current_logger()
            if logger:
                logger.report_media(
                    title="Training Video",
                    series=f"step_{timestep}",
                    local_path=video_path,
                    iteration=timestep,
                )
        # Reset back to training state after recording
        self.env.reset()
class Trainer:
    """Wires a runner, a shared actor-critic model and a skrl PPO agent together."""

    def __init__(self, runner: BaseRunner, config: TrainerConfig):
        self.runner = runner
        self.config = config
        self._init_clearml()
        self._init_agent()

    def _init_clearml(self) -> None:
        """Start a ClearML task when both project and task names are configured."""
        if not (self.config.clearml_project and self.config.clearml_task):
            self.clearml_task = None
            return
        self.clearml_task = Task.init(
            project_name=self.config.clearml_project,
            task_name=self.config.clearml_task,
        )

    def _init_agent(self) -> None:
        """Build rollout memory, shared model and PPO agent from the runner's spaces."""
        cfg = self.config
        device: torch.device = self.runner.device
        obs_space: spaces.Space = self.runner.observation_space
        act_space: spaces.Space = self.runner.action_space
        self.memory: RandomMemory = RandomMemory(
            memory_size=cfg.rollout_steps, num_envs=self.runner.num_envs, device=device
        )
        self.model: SharedMLP = SharedMLP(
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            hidden_sizes=cfg.hidden_sizes,
        )
        agent_cfg = PPO_DEFAULT_CONFIG.copy()
        agent_cfg.update(
            {
                "rollouts": cfg.rollout_steps,
                "learning_epochs": cfg.learning_epochs,
                "mini_batches": cfg.mini_batches,
                "discount_factor": cfg.discount_factor,
                "lambda": cfg.gae_lambda,
                "learning_rate": cfg.learning_rate,
                "ratio_clip": cfg.clip_ratio,
                "value_loss_scale": cfg.value_loss_scale,
                "entropy_loss_scale": cfg.entropy_loss_scale,
            }
        )
        # Policy and value share the same SharedMLP instance (shared trunk).
        self.agent: PPO = PPO(
            models={"policy": self.model, "value": self.model},
            memory=self.memory,
            observation_space=obs_space,
            action_space=act_space,
            device=device,
            cfg=agent_cfg,
        )

    def train(self) -> None:
        """Run the full training loop with periodic video recording."""
        VideoRecordingTrainer(
            env=self.runner,
            agents=self.agent,
            cfg={"timesteps": self.config.total_timesteps},
            trainer_config=self.config,
        ).train()

    def close(self) -> None:
        """Release simulator resources and finish the ClearML task, if any."""
        self.runner.close()
        if self.clearml_task:
            self.clearml_task.close()

47
train.py Normal file
View File

@@ -0,0 +1,47 @@
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig
from src.core.env import ActuatorConfig
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
    """Convert the Hydra env config node into a typed CartPoleConfig."""
    env_dict = OmegaConf.to_container(cfg.env, resolve=True)
    if "actuators" in env_dict:
        # OmegaConf yields plain lists/dicts; coerce each entry into an
        # ActuatorConfig, with ctrl_range as the tuple the dataclass expects.
        parsed = []
        for spec in env_dict["actuators"]:
            if "ctrl_range" in spec:
                spec["ctrl_range"] = tuple(spec["ctrl_range"])
            parsed.append(ActuatorConfig(**spec))
        env_dict["actuators"] = parsed
    return CartPoleConfig(**env_dict)
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Compose Hydra configs, build env/runner/trainer, and run training."""
    env_config = _build_env_config(cfg)
    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    if not training_dict.get("clearml_task"):
        # Derive a ClearML task name from the selected Hydra config-group choices.
        choices = HydraConfig.get().runtime.choices
        name_parts = (
            choices.get("env", "env"),
            choices.get("runner", "runner"),
            choices.get("training", "algo"),
        )
        training_dict["clearml_task"] = "-".join(name_parts)
    trainer_config = TrainerConfig(**training_dict)
    env = CartPoleEnv(env_config)
    runner = MuJoCoRunner(env=env, config=runner_config)
    trainer = Trainer(runner=runner, config=trainer_config)
    try:
        trainer.train()
    finally:
        trainer.close()


if __name__ == "__main__":
    main()