✨ add rotary cartpole env
This commit is contained in:
@@ -57,3 +57,10 @@ class BaseEnv(abc.ABC, Generic[T]):
|
||||
|
||||
def compute_truncations(self, step_counts: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask of envs whose episode reached the step limit."""
    limit = self.config.max_steps
    return step_counts >= limit
|
||||
|
||||
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||
"""Return the default joint positions for reset.
|
||||
Override in subclass if the URDF zero pose doesn't match
|
||||
the desired initial state (e.g. pendulum hanging down).
|
||||
Returns None to use the URDF default (all zeros)."""
|
||||
return None
|
||||
|
||||
@@ -90,8 +90,9 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
raise NotImplementedError("Render method not implemented for this runner.")
|
||||
def render(self, env_idx: int = 0):
    """Offscreen render → RGB numpy array.

    Abstract-by-convention: concrete runners override this to produce
    frames for the requested environment index.
    """
    raise NotImplementedError("Render not implemented for this runner.")
|
||||
|
||||
def close(self) -> None:
    """Release simulator resources held by this runner."""
    self._sim_close()
|
||||
94
src/envs/rotary_cartpole.py
Normal file
94
src/envs/rotary_cartpole.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import dataclasses
import math

import torch
from gymnasium import spaces

from src.core.env import BaseEnv, BaseEnvConfig
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleState:
    """Batched joint state of the rotary cart-pole (Furuta pendulum).

    Each field is a 1-D tensor with one entry per environment; angles are
    consumed via sin/cos downstream, so they are in radians.
    """

    motor_angle: torch.Tensor  # (num_envs,) horizontal motor joint angle
    motor_vel: torch.Tensor  # (num_envs,) motor joint angular velocity
    pendulum_angle: torch.Tensor  # (num_envs,) free pendulum joint angle
    pendulum_vel: torch.Tensor  # (num_envs,) pendulum angular velocity
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RotaryCartPoleConfig(BaseEnvConfig):
    """Rotary inverted pendulum (Furuta pendulum) task config.

    The motor rotates the arm horizontally; the pendulum swings freely
    at the arm tip. Goal: swing the pendulum up and balance it upright.
    """
    # Reward shaping
    # NOTE(review): as written, compute_rewards uses a hard-coded 5.0 gain and
    # does not read this field — confirm intended wiring.
    reward_upright_scale: float = 1.0  # cos(pendulum) when upright = +1
|
||||
|
||||
|
||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
    """Furuta pendulum / rotary inverted pendulum environment.

    Kinematic chain: base_link ─(motor_joint, z)─► arm ─(pendulum_joint, y)─► pendulum

    Observations (6):
        [sin(motor), cos(motor), sin(pendulum), cos(pendulum), motor_vel, pendulum_vel]
        Using sin/cos avoids discontinuities at ±π for continuous joints.

    Actions (1):
        Torque applied to the motor_joint.
    """

    # Base gain on the upright term; multiplied by config.reward_upright_scale
    # (which defaults to 1.0, so the effective default gain stays 5.0).
    _UPRIGHT_GAIN = 5.0
    # Quadratic torque-cost coefficient; discourages bang-bang control.
    _EFFORT_COST = 0.001

    def __init__(self, config: RotaryCartPoleConfig):
        super().__init__(config)

    @property
    def observation_space(self) -> spaces.Space:
        """Unbounded 6-D observation (sin/cos pairs plus two velocities)."""
        return spaces.Box(low=-torch.inf, high=torch.inf, shape=(6,))

    @property
    def action_space(self) -> spaces.Space:
        """Single normalized motor torque command in [-1, 1]."""
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def build_state(self, qpos: torch.Tensor, qvel: torch.Tensor) -> RotaryCartPoleState:
        """Slice the (num_envs, 2) qpos/qvel batches into a typed state."""
        return RotaryCartPoleState(
            motor_angle=qpos[:, 0],
            motor_vel=qvel[:, 0],
            pendulum_angle=qpos[:, 1],
            pendulum_vel=qvel[:, 1],
        )

    def compute_observations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Return the (num_envs, 6) observation tensor.

        Angles are encoded as sin/cos pairs so continuous joints have no
        discontinuity at ±π.
        """
        return torch.stack([
            torch.sin(state.motor_angle),
            torch.cos(state.motor_angle),
            torch.sin(state.pendulum_angle),
            torch.cos(state.pendulum_angle),
            state.motor_vel,
            state.pendulum_vel,
        ], dim=-1)

    def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
        """Dense reward: upright bonus minus a small torque penalty."""
        # height: sin(θ) → -1 (down) to +1 (up)
        height = torch.sin(state.pendulum_angle)

        # Upright reward: strongly rewards being near vertical.
        # Uses cos(θ - π/2) = sin(θ), squared and scaled so:
        #   down  (h=-1): 0.0
        #   horiz (h= 0): 0.0
        #   up    (h=+1): 1.0
        # Only kicks in above horizontal, so swing-up isn't penalised.
        upright_reward = torch.clamp(height, 0.0, 1.0) ** 2

        # Motor effort penalty: small cost to avoid bang-bang control.
        effort_penalty = self._EFFORT_COST * actions.squeeze(-1) ** 2

        # Wire in the (previously unused) config scale; default 1.0 keeps
        # the original effective gain of 5.0.
        gain = self._UPRIGHT_GAIN * self.config.reward_upright_scale
        return gain * upright_reward - effort_penalty

    def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
        """Never terminate early — episodes end only by truncation.

        The agent must learn to swing up AND balance continuously for the
        full max_steps horizon.
        """
        return torch.zeros_like(state.motor_angle, dtype=torch.bool)

    def get_default_qpos(self, nq: int) -> list[float] | None:
        """Initial pose: arm at zero, pendulum hanging straight down.

        The STL mesh is horizontal at qpos=0, so hanging down is
        θ = -π/2 (sin(-π/2) = -1).
        """
        return [0.0, -math.pi / 2]
|
||||
@@ -4,7 +4,7 @@ from gymnasium import spaces
|
||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||
|
||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -20.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
Model.__init__(self, observation_space, action_space, device)
|
||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
||||
DeterministicMixin.__init__(self, clip_actions)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import dataclasses
|
||||
import tempfile
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
device: str = "cpu"
|
||||
dt: float = 0.02
|
||||
substeps: int = 2
|
||||
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
||||
@@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
"""
|
||||
# Step 1: Load URDF/MJCF as-is (no actuators)
|
||||
model_raw = mujoco.MjModel.from_xml_path(model_path)
|
||||
abs_path = os.path.abspath(model_path)
|
||||
model_dir = os.path.dirname(abs_path)
|
||||
is_urdf = abs_path.lower().endswith(".urdf")
|
||||
|
||||
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||
# temporary copy. The original URDF stays clean and simulator-agnostic.
|
||||
if is_urdf:
|
||||
tree = ET.parse(abs_path)
|
||||
root = tree.getroot()
|
||||
# Detect the mesh subdirectory from the first mesh filename
|
||||
meshdir = None
|
||||
for mesh_el in root.iter("mesh"):
|
||||
fn = mesh_el.get("filename", "")
|
||||
dirname = os.path.dirname(fn)
|
||||
if dirname:
|
||||
meshdir = dirname
|
||||
break
|
||||
if meshdir:
|
||||
mj_ext = ET.SubElement(root, "mujoco")
|
||||
ET.SubElement(mj_ext, "compiler", attrib={
|
||||
"meshdir": meshdir,
|
||||
"balanceinertia": "true",
|
||||
})
|
||||
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
||||
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
||||
try:
|
||||
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
||||
finally:
|
||||
os.unlink(tmp_urdf)
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
||||
|
||||
if not actuators:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation
|
||||
tmp_mjcf = tempfile.mktemp(suffix=".xml")
|
||||
# Step 2: Export internal MJCF representation (save next to original
|
||||
# model so relative mesh/asset paths resolve correctly on reload)
|
||||
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||
with open(tmp_mjcf) as f:
|
||||
mjcf_str = f.read()
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
# Use torque actuator. Speed is limited by joint damping:
|
||||
# at steady state, vel_max = torque / damping.
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Add damping to actuated joints to limit max speed and
|
||||
# mimic real motor friction / back-EMF.
|
||||
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
||||
actuated_joints = {a.joint for a in actuators}
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("name") in actuated_joints:
|
||||
jnt.set("damping", "0.05")
|
||||
|
||||
# Disable self-collision on all geoms.
|
||||
# URDF mesh convex hulls often overlap at joints (especially
|
||||
# grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
|
||||
# causing phantom contact forces.
|
||||
for geom in root.iter("geom"):
|
||||
geom.set("contype", "0")
|
||||
geom.set("conaffinity", "0")
|
||||
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
with open(tmp_mjcf, "w") as f:
|
||||
f.write(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
||||
finally:
|
||||
import os
|
||||
os.unlink(tmp_mjcf)
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
root = ET.fromstring(mjcf_str)
|
||||
act_elem = ET.SubElement(root, "actuator")
|
||||
for act in actuators:
|
||||
ET.SubElement(act_elem, "motor", attrib={
|
||||
"name": f"{act.joint}_motor",
|
||||
"joint": act.joint,
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
})
|
||||
|
||||
# Step 4: Reload from modified MJCF
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
return mujoco.MjModel.from_xml_string(modified_xml)
|
||||
if os.path.exists(tmp_mjcf):
|
||||
os.unlink(tmp_mjcf)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
self._nq = self._model.nq
|
||||
self._nv = self._model.nv
|
||||
|
||||
# Per-env smoothed ctrl state for EMA action filtering.
|
||||
# Models real motor inertia: ctrl can't reverse instantly.
|
||||
nu = self._model.nu
|
||||
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
actions_np: np.ndarray = actions.cpu().numpy()
|
||||
alpha = self.config.action_ema_alpha
|
||||
|
||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||
|
||||
for i, data in enumerate(self._data):
|
||||
data.ctrl[:] = actions_np[i]
|
||||
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
|
||||
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
||||
data.ctrl[:] = self._smooth_ctrl[i]
|
||||
for _ in range(self.config.substeps):
|
||||
mujoco.mj_step(self._model, data)
|
||||
|
||||
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
|
||||
|
||||
default_qpos = self.env.get_default_qpos(self._nq)
|
||||
|
||||
for i, env_id in enumerate(ids):
|
||||
data = self._data[env_id]
|
||||
mujoco.mj_resetData(self._model, data)
|
||||
|
||||
# Add small random perturbation so the pole doesn't start perfectly upright
|
||||
# Set initial pose (env-specific, e.g. pendulum hanging down)
|
||||
if default_qpos is not None:
|
||||
data.qpos[:] = default_qpos
|
||||
|
||||
# Add small random perturbation for exploration
|
||||
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
|
||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||
|
||||
# Reset smoothed ctrl so motor starts from rest
|
||||
self._smooth_ctrl[env_id][:] = 0.0
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
qvel_batch[i] = data.qvel
|
||||
|
||||
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
)
|
||||
|
||||
def _sim_close(self) -> None:
    """Tear down the interactive viewer, offscreen renderer, and sim data."""
    viewer = getattr(self, "_viewer", None)
    if viewer is not None:
        viewer.close()
        self._viewer = None

    renderer = getattr(self, "_offscreen_renderer", None)
    if renderer is not None:
        renderer.close()
        self._offscreen_renderer = None

    self._data.clear()
|
||||
|
||||
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
|
||||
if mode == "human":
|
||||
if not hasattr(self, "_viewer") or self._viewer is None:
|
||||
self._viewer = mujoco.viewer.launch_passive(
|
||||
self._model, self._data[env_idx]
|
||||
)
|
||||
# Update visual geometry from current physics state
|
||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||
self._viewer.sync()
|
||||
return None
|
||||
elif mode == "rgb_array":
|
||||
# Cache the offscreen renderer to avoid create/destroy overhead
|
||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
|
||||
return torch.from_numpy(pixels)
|
||||
def render(self, env_idx: int = 0) -> np.ndarray | None:
|
||||
"""Offscreen render → RGB numpy array (H, W, 3)."""
|
||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
return self._offscreen_renderer.render().copy()
|
||||
@@ -4,19 +4,22 @@ import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from clearml import Logger
|
||||
from gymnasium import spaces
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from skrl.resources.preprocessors.torch import RunningStandardScaler
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
from src.core.runner import BaseRunner
|
||||
from clearml import Task, Logger
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
from skrl.memories.torch import RandomMemory
|
||||
from src.models.mlp import SharedMLP
|
||||
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
|
||||
from skrl.trainers.torch import SequentialTrainer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainerConfig:
|
||||
# PPO
|
||||
rollout_steps: int = 2048
|
||||
learning_epochs: int = 8
|
||||
mini_batches: int = 4
|
||||
@@ -29,30 +32,27 @@ class TrainerConfig:
|
||||
|
||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
|
||||
# Video recording
|
||||
record_video_every: int = 10000 # record a video every N timesteps (0 = disabled)
|
||||
record_video_min_seconds: float = 10.0 # minimum video duration in seconds
|
||||
record_video_fps: int = 0 # 0 = auto-derive from simulation rate
|
||||
# Video recording (uploaded to ClearML)
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||
|
||||
clearml_project: str | None = None
|
||||
clearml_task: str | None = None
|
||||
|
||||
# ── Video-recording trainer ──────────────────────────────────────────
|
||||
|
||||
class VideoRecordingTrainer(SequentialTrainer):
|
||||
"""Subclass of skrl's SequentialTrainer that records videos periodically."""
|
||||
"""SequentialTrainer with periodic evaluation videos uploaded to ClearML."""
|
||||
|
||||
def __init__(self, env, agents, cfg=None, trainer_config: TrainerConfig | None = None):
|
||||
super().__init__(env=env, agents=agents, cfg=cfg)
|
||||
self._trainer_config = trainer_config
|
||||
self._tcfg = trainer_config
|
||||
self._video_dir = Path(tempfile.mkdtemp(prefix="rl_videos_"))
|
||||
|
||||
def single_agent_train(self) -> None:
|
||||
"""Override to add periodic video recording."""
|
||||
assert self.num_simultaneous_agents == 1
|
||||
assert self.env.num_agents == 1
|
||||
assert self.num_simultaneous_agents == 1 and self.env.num_agents == 1
|
||||
|
||||
states, infos = self.env.reset()
|
||||
|
||||
@@ -61,26 +61,17 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
disable=self.disable_progressbar,
|
||||
file=sys.stdout,
|
||||
):
|
||||
# Pre-interaction
|
||||
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
with torch.no_grad():
|
||||
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
|
||||
|
||||
if not self.headless:
|
||||
self.env.render()
|
||||
|
||||
self.agents.record_transition(
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_states=next_states,
|
||||
terminated=terminated,
|
||||
truncated=truncated,
|
||||
infos=infos,
|
||||
timestep=timestep,
|
||||
timesteps=self.timesteps,
|
||||
states=states, actions=actions, rewards=rewards,
|
||||
next_states=next_states, terminated=terminated,
|
||||
truncated=truncated, infos=infos,
|
||||
timestep=timestep, timesteps=self.timesteps,
|
||||
)
|
||||
|
||||
if self.environment_info in infos:
|
||||
@@ -90,7 +81,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
|
||||
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)
|
||||
|
||||
# Reset environments
|
||||
# Auto-reset for multi-env; single-env resets manually
|
||||
if self.env.num_envs > 1:
|
||||
states = next_states
|
||||
else:
|
||||
@@ -100,111 +91,90 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
else:
|
||||
states = next_states
|
||||
|
||||
# Record video at intervals
|
||||
cfg = self._trainer_config
|
||||
# Periodic video recording
|
||||
if (
|
||||
cfg
|
||||
and cfg.record_video_every > 0
|
||||
and (timestep + 1) % cfg.record_video_every == 0
|
||||
self._tcfg
|
||||
and self._tcfg.record_video_every > 0
|
||||
and (timestep + 1) % self._tcfg.record_video_every == 0
|
||||
):
|
||||
self._record_video(timestep + 1)
|
||||
|
||||
def _get_video_fps(self) -> int:
|
||||
"""Derive video fps from the simulation rate, or use configured value."""
|
||||
cfg = self._trainer_config
|
||||
if cfg.record_video_fps > 0:
|
||||
return cfg.record_video_fps
|
||||
# Auto-derive from runner's simulation parameters
|
||||
runner = self.env
|
||||
dt = getattr(runner.config, "dt", 0.02)
|
||||
substeps = getattr(runner.config, "substeps", 1)
|
||||
# ── helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _get_fps(self) -> int:
    """Video fps: explicit config value if positive, else 1 / (dt × substeps)."""
    tcfg = self._tcfg
    if tcfg and tcfg.record_video_fps > 0:
        return tcfg.record_video_fps
    # Derive from the runner's simulation rate, with safe fallbacks.
    dt = getattr(self.env.config, "dt", 0.02)
    substeps = getattr(self.env.config, "substeps", 1)
    return max(1, int(round(1.0 / (dt * substeps))))
|
||||
|
||||
def _record_video(self, timestep: int) -> None:
|
||||
"""Record evaluation episodes and upload to ClearML."""
|
||||
try:
|
||||
import imageio.v3 as iio
|
||||
except ImportError:
|
||||
try:
|
||||
import imageio as iio
|
||||
except ImportError:
|
||||
return
|
||||
return
|
||||
|
||||
cfg = self._trainer_config
|
||||
fps = self._get_video_fps()
|
||||
min_frames = int(cfg.record_video_min_seconds * fps)
|
||||
max_frames = min_frames * 3 # hard cap to prevent runaway recording
|
||||
fps = self._get_fps()
|
||||
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
frames: list[np.ndarray] = []
|
||||
|
||||
while len(frames) < min_frames and len(frames) < max_frames:
|
||||
obs, _ = self.env.reset()
|
||||
done = False
|
||||
steps = 0
|
||||
max_episode_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||
while not done and steps < max_episode_steps:
|
||||
with torch.no_grad():
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
frame = self.env.render(mode="rgb_array")
|
||||
if frame is not None:
|
||||
frames.append(frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame)
|
||||
done = (terminated | truncated).any().item()
|
||||
steps += 1
|
||||
if len(frames) >= max_frames:
|
||||
break
|
||||
obs, _ = self.env.reset()
|
||||
for _ in range(max_steps):
|
||||
with torch.no_grad():
|
||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||
|
||||
frame = self.env.render()
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
if (terminated | truncated).any().item():
|
||||
break
|
||||
|
||||
if frames:
|
||||
video_path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(video_path, frames, fps=fps)
|
||||
path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||
iio.imwrite(path, frames, fps=fps)
|
||||
|
||||
logger = Logger.current_logger()
|
||||
if logger:
|
||||
logger.report_media(
|
||||
title="Training Video",
|
||||
series=f"step_{timestep}",
|
||||
local_path=video_path,
|
||||
iteration=timestep,
|
||||
"Training Video", f"step_{timestep}",
|
||||
local_path=path, iteration=timestep,
|
||||
)
|
||||
|
||||
# Reset back to training state after recording
|
||||
# Restore training state
|
||||
self.env.reset()
|
||||
|
||||
|
||||
# ── Main trainer ─────────────────────────────────────────────────────
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, runner: BaseRunner, config: TrainerConfig):
|
||||
self.runner = runner
|
||||
self.config = config
|
||||
|
||||
self._init_clearml()
|
||||
self._init_agent()
|
||||
|
||||
def _init_clearml(self) -> None:
    """Create a ClearML task when both project and task names are configured."""
    cfg = self.config
    if cfg.clearml_project and cfg.clearml_task:
        self.clearml_task = Task.init(
            project_name=cfg.clearml_project,
            task_name=cfg.clearml_task,
        )
    else:
        # ClearML disabled — downstream code checks for a falsy task.
        self.clearml_task = None
|
||||
|
||||
def _init_agent(self) -> None:
|
||||
device: torch.device = self.runner.device
|
||||
obs_space: spaces.Space = self.runner.observation_space
|
||||
act_space: spaces.Space = self.runner.action_space
|
||||
num_envs: int = self.runner.num_envs
|
||||
device = self.runner.device
|
||||
obs_space = self.runner.observation_space
|
||||
act_space = self.runner.action_space
|
||||
|
||||
self.memory: RandomMemory = RandomMemory(memory_size=self.config.rollout_steps, num_envs=num_envs, device=device)
|
||||
self.memory = RandomMemory(
|
||||
memory_size=self.config.rollout_steps,
|
||||
num_envs=self.runner.num_envs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.model: SharedMLP = SharedMLP(
|
||||
self.model = SharedMLP(
|
||||
observation_space=obs_space,
|
||||
action_space=act_space,
|
||||
device=device,
|
||||
hidden_sizes=self.config.hidden_sizes,
|
||||
initial_log_std=0.5,
|
||||
min_log_std=-2.0,
|
||||
)
|
||||
|
||||
models = {
|
||||
"policy": self.model,
|
||||
"value": self.model,
|
||||
}
|
||||
models = {"policy": self.model, "value": self.model}
|
||||
|
||||
agent_cfg = PPO_DEFAULT_CONFIG.copy()
|
||||
agent_cfg.update({
|
||||
@@ -217,9 +187,19 @@ class Trainer:
|
||||
"ratio_clip": self.config.clip_ratio,
|
||||
"value_loss_scale": self.config.value_loss_scale,
|
||||
"entropy_loss_scale": self.config.entropy_loss_scale,
|
||||
"state_preprocessor": RunningStandardScaler,
|
||||
"state_preprocessor_kwargs": {"size": obs_space, "device": device},
|
||||
"value_preprocessor": RunningStandardScaler,
|
||||
"value_preprocessor_kwargs": {"size": 1, "device": device},
|
||||
})
|
||||
# Wire up logging frequency: write_interval is in timesteps.
|
||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||
self.config.total_timesteps // 10, self.config.rollout_steps
|
||||
)
|
||||
|
||||
self.agent: PPO = PPO(
|
||||
self.agent = PPO(
|
||||
models=models,
|
||||
memory=self.memory,
|
||||
observation_space=obs_space,
|
||||
@@ -238,6 +218,4 @@ class Trainer:
|
||||
trainer.train()
|
||||
|
||||
def close(self) -> None:
|
||||
self.runner.close()
|
||||
if self.clearml_task:
|
||||
self.clearml_task.close()
|
||||
self.runner.close()
|
||||
Reference in New Issue
Block a user