diff --git a/configs/runner/mjx.yaml b/configs/runner/mjx.yaml
new file mode 100644
index 0000000..891f353
--- /dev/null
+++ b/configs/runner/mjx.yaml
@@ -0,0 +1,5 @@
+num_envs: 1024 # MJX shines with many parallel envs
+device: auto # auto = cuda if available, else cpu
+dt: 0.002
+substeps: 20
+action_ema_alpha: 0.2
diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml
index 861b8ba..418f2bd 100644
--- a/configs/runner/mujoco.yaml
+++ b/configs/runner/mujoco.yaml
@@ -1,5 +1,5 @@
 num_envs: 64
-device: cpu
+device: auto # auto = cuda if available, else cpu
 dt: 0.002
 substeps: 20
 action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
diff --git a/requirements.txt b/requirements.txt
index 4ac07a4..8466a62 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,8 @@ gymnasium
 hydra-core
 omegaconf
 mujoco
+mujoco-mjx
+jax
 skrl[torch]
 clearml
 imageio
diff --git a/src/core/runner.py b/src/core/runner.py
index ff00890..b28fe28 100644
--- a/src/core/runner.py
+++ b/src/core/runner.py
@@ -17,6 +17,10 @@ class BaseRunner(abc.ABC, Generic[T]):
         self.env = env
         self.config = config
 
+        # Resolve "auto" device before anything uses it
+        if getattr(self.config, "device", None) == "auto":
+            self.config.device = "cuda" if torch.cuda.is_available() else "cpu"
+
         self._sim_initialize(config)
 
         self.observation_space = self.env.observation_space
diff --git a/src/runners/mjx.py b/src/runners/mjx.py
new file mode 100644
index 0000000..63b5f73
--- /dev/null
+++ b/src/runners/mjx.py
@@ -0,0 +1,227 @@
+"""GPU-batched MuJoCo simulation using MJX (JAX backend).
+
+MJX runs all environments in parallel on GPU via JAX, providing
+~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).
+
+Requirements:
+    pip install 'jax[cuda12]'  # NVIDIA GPU
+    pip install jax            # CPU fallback
+"""
+
+import dataclasses
+
+import structlog
+import torch
+
+try:
+    import jax
+    import jax.numpy as jnp
+    import mujoco
+    from mujoco import mjx
+except ImportError as e:
+    raise ImportError(
+        "MJX runner requires JAX and MuJoCo MJX. Install with:\n"
+        "    pip install 'jax[cuda12]'  # GPU\n"
+        "    pip install jax            # CPU\n"
+    ) from e
+
+import numpy as np
+
+from src.core.env import BaseEnv
+from src.core.runner import BaseRunner, BaseRunnerConfig
+from src.runners.mujoco import MuJoCoRunner  # reuse _load_model_with_actuators
+
+log = structlog.get_logger()
+
+
+@dataclasses.dataclass
+class MJXRunnerConfig(BaseRunnerConfig):
+    num_envs: int = 1024
+    device: str = "cuda"
+    dt: float = 0.002
+    substeps: int = 20
+    action_ema_alpha: float = 0.2
+
+
+class MJXRunner(BaseRunner[MJXRunnerConfig]):
+    """GPU-batched MuJoCo runner using MJX (JAX).
+
+    Physics runs entirely on GPU via JAX; observations flow to
+    PyTorch through zero-copy DLPack transfers.
+    """
+
+    def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
+        super().__init__(env, config)
+
+    @property
+    def num_envs(self) -> int:
+        return self.config.num_envs
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(self.config.device)
+
+    # ── Initialization ───────────────────────────────────────────────
+
+    def _sim_initialize(self, config: MJXRunnerConfig) -> None:
+        model_path = self.env.config.model_path
+        if model_path is None:
+            raise ValueError("model_path must be specified")
+
+        # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
+        self._mj_model = MuJoCoRunner._load_model_with_actuators(
+            str(model_path), self.env.config.actuators,
+        )
+        self._mj_model.opt.timestep = config.dt
+        self._nq = self._mj_model.nq
+        self._nv = self._mj_model.nv
+        self._nu = self._mj_model.nu
+
+        # Step 2: Put model on GPU
+        self._mjx_model = mjx.put_model(self._mj_model)
+
+        # Step 3: Default reset state on GPU
+        default_data = mujoco.MjData(self._mj_model)
+        default_qpos = self.env.get_default_qpos(self._nq)
+        if default_qpos is not None:
+            default_data.qpos[:] = default_qpos
+        mujoco.mj_forward(self._mj_model, default_data)
+        self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
+
+        # Step 4: Initialise all environments with small perturbations
+        self._rng = jax.random.PRNGKey(42)
+        self._batch_data = self._make_batched_data(config.num_envs)
+
+        # Step 5: EMA ctrl state (on GPU as JAX array)
+        self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
+
+        # Step 6: JIT-compile the hot-path functions
+        self._compile_jit_fns(config.substeps)
+
+        # Keep one CPU MjData for offscreen rendering
+        self._render_data = mujoco.MjData(self._mj_model)
+
+        log.info(
+            "mjx_runner_ready",
+            num_envs=config.num_envs,
+            substeps=config.substeps,
+            jax_devices=str(jax.devices()),
+        )
+
+    def _make_batched_data(self, n: int):
+        """Create *n* environments with small random perturbations."""
+        self._rng, k1, k2 = jax.random.split(self._rng, 3)
+        pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05)
+        pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05)
+
+        default = self._default_mjx_data
+        model = self._mjx_model
+
+        def _init_one(pq_i, pv_i):
+            d = default.replace(
+                qpos=default.qpos + pq_i,
+                qvel=default.qvel + pv_i,
+            )
+            return mjx.forward(model, d)
+
+        return jax.vmap(_init_one)(pq, pv)
+
+    def _compile_jit_fns(self, substeps: int) -> None:
+        """Pre-compile the two hot-path functions so the first call is fast."""
+        model = self._mjx_model
+        default = self._default_mjx_data
+
+        # ── Batched step (N substeps per call) ──────────────────────
+        @jax.jit
+        def step_fn(data):
+            def body(_, d):
+                return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
+
+            return jax.lax.fori_loop(0, substeps, body, data)
+
+        self._jit_step = step_fn
+
+        # ── Selective reset ─────────────────────────────────────────
+        @jax.jit
+        def reset_fn(data, smooth_ctrl, mask, rng):
+            rng, k1, k2 = jax.random.split(rng, 3)
+            ne = data.qpos.shape[0]
+
+            pq = jax.random.uniform(
+                k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05,
+            )
+            pv = jax.random.uniform(
+                k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05,
+            )
+
+            m = mask[:, None]  # (num_envs, 1) broadcast helper
+
+            new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
+            new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
+            new_ctrl = jnp.where(m, 0.0, data.ctrl)
+            new_smooth = jnp.where(m, 0.0, smooth_ctrl)
+
+            new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
+            new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
+
+            return new_data, new_smooth, rng
+
+        self._jit_reset = reset_fn
+
+    # ── Step / Reset ─────────────────────────────────────────────────
+
+    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # PyTorch → JAX (zero-copy on GPU via DLPack)
+        actions_jax = jnp.from_dlpack(actions.detach().contiguous())
+
+        # EMA smoothing (vectorised on GPU, no Python loop)
+        alpha = self.config.action_ema_alpha
+        self._smooth_ctrl = (
+            alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
+        )
+
+        # Set ctrl & run N substeps for all environments
+        self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
+        self._batch_data = self._jit_step(self._batch_data)
+
+        # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
+        qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
+        qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
+        return qpos, qvel
+
+    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # Build boolean mask (fixed shape → no JIT recompilation)
+        mask = torch.zeros(
+            self.config.num_envs, dtype=torch.bool, device=self.device,
+        )
+        mask[env_ids] = True
+        mask_jax = jnp.from_dlpack(mask)
+
+        self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
+            self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
+        )
+
+        # Return only the reset environments' states
+        ids_np = env_ids.cpu().numpy()
+        rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
+        rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
+        return torch.from_dlpack(rq), torch.from_dlpack(rv)
+
+    def _sim_close(self) -> None:
+        if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
+            self._offscreen_renderer.close()
+
+    # ── Rendering ────────────────────────────────────────────────────
+
+    def render(self, env_idx: int = 0) -> np.ndarray:
+        """Offscreen render — copies one env's state from GPU to CPU."""
+        self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
+        self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
+        mujoco.mj_forward(self._mj_model, self._render_data)
+
+        if not hasattr(self, "_offscreen_renderer"):
+            self._offscreen_renderer = mujoco.Renderer(
+                self._mj_model, width=640, height=480,
+            )
+        self._offscreen_renderer.update_scene(self._render_data)
+        return self._offscreen_renderer.render()
diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py
index f8c79b5..5e140c3 100644
--- a/src/runners/mujoco.py
+++ b/src/runners/mujoco.py
@@ -198,12 +198,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
     def _sim_close(self) -> None:
         if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
             self._offscreen_renderer.close()
-            self._offscreen_renderer = None
-        self._data.clear()
 
-    def render(self, env_idx: int = 0) -> np.ndarray | None:
-        """Offscreen render → RGB numpy array (H, W, 3)."""
-        if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
-            self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
+    def render(self, env_idx: int = 0) -> np.ndarray:
+        """Offscreen render of a single environment."""
+        if not hasattr(self, "_offscreen_renderer"):
+            self._offscreen_renderer = mujoco.Renderer(
+                self._model, width=640, height=480,
+            )
+        mujoco.mj_forward(self._model, self._data[env_idx])
         self._offscreen_renderer.update_scene(self._data[env_idx])
-        return self._offscreen_renderer.render().copy()
\ No newline at end of file
+        return self._offscreen_renderer.render()
\ No newline at end of file
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 5230049..ce66e5f 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -4,6 +4,7 @@ import tempfile
 from pathlib import Path
 
 import numpy as np
+import structlog
 import torch
 import tqdm
 from clearml import Logger
@@ -16,6 +17,8 @@ from skrl.trainers.torch import SequentialTrainer
 from src.core.runner import BaseRunner
 from src.models.mlp import SharedMLP
 
+log = structlog.get_logger()
+
 
 @dataclasses.dataclass
 class TrainerConfig:
@@ -125,7 +128,14 @@
             action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
             obs, _, terminated, truncated, _ = self.env.step(action)
 
-            frame = self.env.render()
+            try:
+                frame = self.env.render()
+            except Exception:
+                # Headless environment without OpenGL — skip video recording
+                log.warning("video_recording_disabled", reason="render failed (headless?)")
+                self.env.reset()
+                return
+
             if frame is not None:
                 frames.append(frame)
 
diff --git a/train.py b/train.py
index 8d16e80..ddcce35 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,10 @@
+import os
 import pathlib
+import sys
+
+# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
+if sys.platform == "linux" and "DISPLAY" not in os.environ:
+    os.environ.setdefault("MUJOCO_GL", "osmesa")
 
 import hydra
 import hydra.utils as hydra_utils
@@ -8,9 +14,9 @@ from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig, OmegaConf
 
 from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
+from src.core.runner import BaseRunner
 from src.envs.cartpole import CartPoleEnv, CartPoleConfig
 from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
-from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
 from src.training.trainer import Trainer, TrainerConfig
 
 logger = structlog.get_logger()
@@ -41,6 +47,33 @@ def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
     return env_cls(config_cls(**env_dict))
 
 
+# ── runner registry ───────────────────────────────────────────────────
+# Maps Hydra config-group name → (RunnerClass, ConfigClass)
+# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
+
+RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
+    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
+    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
+}
+
+
+def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
+    """Instantiate the right runner from the Hydra config-group name."""
+    if runner_name not in RUNNER_REGISTRY:
+        raise ValueError(
+            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
+        )
+    module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
+
+    import importlib
+    mod = importlib.import_module(module_path)
+    runner_cls = getattr(mod, cls_name)
+    config_cls = getattr(mod, cfg_cls_name)
+
+    runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
+    return runner_cls(env=env, config=runner_config)
+
+
 def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
     """Initialize ClearML task with project structure and tags.
 
@@ -58,7 +91,14 @@
 
     tags = [env_name, runner_name, training_name]
     task = Task.init(project_name=project, task_name=task_name, tags=tags)
-    task.set_base_docker("registry.kube.optimize/worker-image:latest")
+    task.set_base_docker(
+        "registry.kube.optimize/worker-image:latest",
+        docker_setup_bash_script=(
+            "apt-get update && apt-get install -y --no-install-recommends "
+            "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
+            "&& pip install 'jax[cuda12]' mujoco-mjx"
+        ),
+    )
 
     req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
     task.set_packages(str(req_file))
@@ -84,10 +124,8 @@
     env_name = choices.get("env", "cartpole")
     env = _build_env(env_name, cfg)
 
-    runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
+    runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
     trainer_config = TrainerConfig(**training_dict)
-
-    runner = MuJoCoRunner(env=env, config=runner_config)
     trainer = Trainer(runner=runner, config=trainer_config)
 
     try: