add mjx runner

This commit is contained in:
2026-03-09 21:18:19 +01:00
parent 15da0ef2fd
commit 26ccb1e902
8 changed files with 301 additions and 14 deletions

5
configs/runner/mjx.yaml Normal file
View File

@@ -0,0 +1,5 @@
num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu
dt: 0.002
substeps: 20
action_ema_alpha: 0.2

View File

@@ -1,5 +1,5 @@
num_envs: 64 num_envs: 64
device: cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 20 substeps: 20
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse) action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)

View File

@@ -3,6 +3,8 @@ gymnasium
hydra-core hydra-core
omegaconf omegaconf
mujoco mujoco
mujoco-mjx
jax
skrl[torch] skrl[torch]
clearml clearml
imageio imageio

View File

@@ -17,6 +17,10 @@ class BaseRunner(abc.ABC, Generic[T]):
self.env = env self.env = env
self.config = config self.config = config
# Resolve "auto" device before anything uses it
if getattr(self.config, "device", None) == "auto":
self.config.device = "cuda" if torch.cuda.is_available() else "cpu"
self._sim_initialize(config) self._sim_initialize(config)
self.observation_space = self.env.observation_space self.observation_space = self.env.observation_space

227
src/runners/mjx.py Normal file
View File

@@ -0,0 +1,227 @@
"""GPU-batched MuJoCo simulation using MJX (JAX backend).
MJX runs all environments in parallel on GPU via JAX, providing
~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).
Requirements:
pip install 'jax[cuda12]' # NVIDIA GPU
pip install jax # CPU fallback
"""
import dataclasses
import structlog
import torch
try:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
except ImportError as e:
raise ImportError(
"MJX runner requires JAX and MuJoCo MJX. Install with:\n"
" pip install 'jax[cuda12]' # GPU\n"
" pip install jax # CPU\n"
) from e
import numpy as np
from src.core.env import BaseEnv
from src.core.runner import BaseRunner, BaseRunnerConfig
from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators
log = structlog.get_logger()
@dataclasses.dataclass
class MJXRunnerConfig(BaseRunnerConfig):
    """Configuration for the GPU-batched MJX runner.

    Defaults mirror ``configs/runner/mjx.yaml`` so that constructing the
    dataclass directly (e.g. in tests) behaves like the Hydra config.
    """

    # Number of environments simulated in parallel; MJX scales best at 1024+.
    num_envs: int = 1024
    # "auto" is resolved by BaseRunner.__init__ to "cuda" if available else
    # "cpu" — a hard-coded "cuda" default would break CPU-only machines and
    # disagree with the YAML default.
    device: str = "auto"
    # Physics integration timestep (seconds), written into mj_model.opt.timestep.
    dt: float = 0.002
    # Number of physics substeps executed per control step.
    substeps: int = 20
    # EMA coefficient for motor-command smoothing (higher = less smoothing).
    action_ema_alpha: float = 0.2
class MJXRunner(BaseRunner[MJXRunnerConfig]):
    """GPU-batched MuJoCo runner using MJX (JAX).

    Physics for all environments runs in parallel on the accelerator via
    JAX (`vmap` over `mjx.step`); observations cross into PyTorch through
    zero-copy DLPack transfers. A single CPU ``MjData`` is kept solely for
    offscreen rendering of one environment at a time.
    """

    def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of parallel environments."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """PyTorch device observations are delivered on."""
        return torch.device(self.config.device)

    # ── Initialization ───────────────────────────────────────────────
    def _sim_initialize(self, config: MJXRunnerConfig) -> None:
        """Load the model, move it to the accelerator, and JIT the hot paths.

        Raises:
            ValueError: if the env config carries no ``model_path``.
        """
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified")
        # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
        self._mj_model = MuJoCoRunner._load_model_with_actuators(
            str(model_path), self.env.config.actuators,
        )
        self._mj_model.opt.timestep = config.dt
        self._nq = self._mj_model.nq
        self._nv = self._mj_model.nv
        self._nu = self._mj_model.nu
        # Step 2: Put model on GPU
        self._mjx_model = mjx.put_model(self._mj_model)
        # Step 3: Default reset state on GPU
        default_data = mujoco.MjData(self._mj_model)
        default_qpos = self.env.get_default_qpos(self._nq)
        if default_qpos is not None:
            default_data.qpos[:] = default_qpos
        # Forward pass so derived quantities (xpos, contacts, …) are valid
        # before the state is uploaded as the canonical reset pose.
        mujoco.mj_forward(self._mj_model, default_data)
        self._default_mjx_data = mjx.put_data(self._mj_model, default_data)
        # Step 4: Initialise all environments with small perturbations
        self._rng = jax.random.PRNGKey(42)
        self._batch_data = self._make_batched_data(config.num_envs)
        # Step 5: EMA ctrl state (on GPU as JAX array)
        self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
        # Step 6: JIT-compile the hot-path functions
        self._compile_jit_fns(config.substeps)
        # Keep one CPU MjData for offscreen rendering
        self._render_data = mujoco.MjData(self._mj_model)
        # Renderer is created lazily on first render() and torn down in
        # _sim_close(); initialise the slot so both sides can use a single
        # "is None" check.
        self._offscreen_renderer = None
        log.info(
            "mjx_runner_ready",
            num_envs=config.num_envs,
            substeps=config.substeps,
            jax_devices=str(jax.devices()),
        )

    def _make_batched_data(self, n: int):
        """Create *n* environments, each perturbed ±0.05 around the default pose."""
        self._rng, k1, k2 = jax.random.split(self._rng, 3)
        pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05)
        pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05)
        default = self._default_mjx_data
        model = self._mjx_model

        def _init_one(pq_i, pv_i):
            d = default.replace(
                qpos=default.qpos + pq_i,
                qvel=default.qvel + pv_i,
            )
            return mjx.forward(model, d)

        return jax.vmap(_init_one)(pq, pv)

    def _compile_jit_fns(self, substeps: int) -> None:
        """Pre-compile the two hot-path functions so the first call is fast.

        Both closures capture the device model and default state, so they
        only ever retrace if the batch shape changes.
        """
        model = self._mjx_model
        default = self._default_mjx_data

        # ── Batched step (N substeps per call) ──────────────────────
        @jax.jit
        def step_fn(data):
            def body(_, d):
                return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
            return jax.lax.fori_loop(0, substeps, body, data)

        self._jit_step = step_fn

        # ── Selective reset ─────────────────────────────────────────
        @jax.jit
        def reset_fn(data, smooth_ctrl, mask, rng):
            rng, k1, k2 = jax.random.split(rng, 3)
            ne = data.qpos.shape[0]
            pq = jax.random.uniform(
                k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05,
            )
            pv = jax.random.uniform(
                k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05,
            )
            m = mask[:, None]  # (num_envs, 1) broadcast helper
            # Masked envs get a fresh perturbed state; others keep theirs.
            new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
            new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
            new_ctrl = jnp.where(m, 0.0, data.ctrl)
            new_smooth = jnp.where(m, 0.0, smooth_ctrl)
            new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
            new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
            return new_data, new_smooth, rng

        self._jit_reset = reset_fn

    # ── Step / Reset ─────────────────────────────────────────────────
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply EMA-smoothed actions, advance all envs, return (qpos, qvel)."""
        # PyTorch → JAX (zero-copy on GPU via DLPack)
        actions_jax = jnp.from_dlpack(actions.detach().contiguous())
        # EMA smoothing (vectorised on GPU, no Python loop)
        alpha = self.config.action_ema_alpha
        self._smooth_ctrl = (
            alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
        )
        # Set ctrl & run N substeps for all environments
        self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
        self._batch_data = self._jit_step(self._batch_data)
        # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
        qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
        qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
        return qpos, qvel

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the environments in *env_ids*; return their (qpos, qvel)."""
        # Build boolean mask (fixed shape → no JIT recompilation)
        mask = torch.zeros(
            self.config.num_envs, dtype=torch.bool, device=self.device,
        )
        mask[env_ids] = True
        # NOTE(review): bool-dtype DLPack exchange requires reasonably recent
        # torch/jax versions — confirm on the deployed combination.
        mask_jax = jnp.from_dlpack(mask)
        self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
            self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
        )
        # Return only the reset environments' states
        ids_np = env_ids.cpu().numpy()
        rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
        rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
        return torch.from_dlpack(rq), torch.from_dlpack(rv)

    def _sim_close(self) -> None:
        """Release the offscreen renderer (safe to call repeatedly)."""
        if getattr(self, "_offscreen_renderer", None) is not None:
            self._offscreen_renderer.close()
            # Clear the slot so a later render() recreates the renderer
            # instead of using a closed one (matches MuJoCoRunner behaviour).
            self._offscreen_renderer = None

    # ── Rendering ────────────────────────────────────────────────────
    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen render — copies one env's state from GPU to CPU.

        Returns:
            RGB image array from the MuJoCo renderer.
        """
        self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
        self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
        # Recompute derived quantities on CPU so the scene matches the state.
        mujoco.mj_forward(self._mj_model, self._render_data)
        # "is None" (not hasattr) so the renderer is recreated after a close.
        if getattr(self, "_offscreen_renderer", None) is None:
            self._offscreen_renderer = mujoco.Renderer(
                self._mj_model, width=640, height=480,
            )
        self._offscreen_renderer.update_scene(self._render_data)
        return self._offscreen_renderer.render()

View File

@@ -198,12 +198,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def _sim_close(self) -> None: def _sim_close(self) -> None:
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close() self._offscreen_renderer.close()
self._offscreen_renderer = None
self._data.clear()
def render(self, env_idx: int = 0) -> np.ndarray | None: def render(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render → RGB numpy array (H, W, 3).""" """Offscreen render of a single environment."""
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None: if not hasattr(self, "_offscreen_renderer"):
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640) self._offscreen_renderer = mujoco.Renderer(
self._model, width=640, height=480,
)
mujoco.mj_forward(self._model, self._data[env_idx])
self._offscreen_renderer.update_scene(self._data[env_idx]) self._offscreen_renderer.update_scene(self._data[env_idx])
return self._offscreen_renderer.render().copy() return self._offscreen_renderer.render()

View File

@@ -4,6 +4,7 @@ import tempfile
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import structlog
import torch import torch
import tqdm import tqdm
from clearml import Logger from clearml import Logger
@@ -16,6 +17,8 @@ from skrl.trainers.torch import SequentialTrainer
from src.core.runner import BaseRunner from src.core.runner import BaseRunner
from src.models.mlp import SharedMLP from src.models.mlp import SharedMLP
log = structlog.get_logger()
@dataclasses.dataclass @dataclasses.dataclass
class TrainerConfig: class TrainerConfig:
@@ -125,7 +128,14 @@ class VideoRecordingTrainer(SequentialTrainer):
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0] action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action) obs, _, terminated, truncated, _ = self.env.step(action)
try:
frame = self.env.render() frame = self.env.render()
except Exception:
# Headless environment without OpenGL — skip video recording
log.warning("video_recording_disabled", reason="render failed (headless?)")
self.env.reset()
return
if frame is not None: if frame is not None:
frames.append(frame) frames.append(frame)

View File

@@ -1,4 +1,10 @@
import os
import pathlib import pathlib
import sys
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
if sys.platform == "linux" and "DISPLAY" not in os.environ:
os.environ.setdefault("MUJOCO_GL", "osmesa")
import hydra import hydra
import hydra.utils as hydra_utils import hydra.utils as hydra_utils
@@ -8,9 +14,9 @@ from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.core.runner import BaseRunner
from src.envs.cartpole import CartPoleEnv, CartPoleConfig from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
from src.training.trainer import Trainer, TrainerConfig from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger() logger = structlog.get_logger()
@@ -41,6 +47,33 @@ def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
return env_cls(config_cls(**env_dict)) return env_cls(config_cls(**env_dict))
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
}
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
"""Instantiate the right runner from the Hydra config-group name."""
if runner_name not in RUNNER_REGISTRY:
raise ValueError(
f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
)
module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
import importlib
mod = importlib.import_module(module_path)
runner_cls = getattr(mod, cls_name)
config_cls = getattr(mod, cfg_cls_name)
runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
return runner_cls(env=env, config=runner_config)
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task: def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"""Initialize ClearML task with project structure and tags. """Initialize ClearML task with project structure and tags.
@@ -58,7 +91,14 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
tags = [env_name, runner_name, training_name] tags = [env_name, runner_name, training_name]
task = Task.init(project_name=project, task_name=task_name, tags=tags) task = Task.init(project_name=project, task_name=task_name, tags=tags)
task.set_base_docker("registry.kube.optimize/worker-image:latest") task.set_base_docker(
"registry.kube.optimize/worker-image:latest",
docker_setup_bash_script=(
"apt-get update && apt-get install -y --no-install-recommends "
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
"&& pip install 'jax[cuda12]' mujoco-mjx"
),
)
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt" req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
task.set_packages(str(req_file)) task.set_packages(str(req_file))
@@ -84,10 +124,8 @@ def main(cfg: DictConfig) -> None:
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "cartpole")
env = _build_env(env_name, cfg) env = _build_env(env_name, cfg)
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True)) runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict) trainer_config = TrainerConfig(**training_dict)
runner = MuJoCoRunner(env=env, config=runner_config)
trainer = Trainer(runner=runner, config=trainer_config) trainer = Trainer(runner=runner, config=trainer_config)
try: try: