Files
RL-Sim-Framework/src/runners/mjx.py
2026-03-09 22:17:28 +01:00

218 lines
8.1 KiB
Python

"""GPU-batched MuJoCo simulation using MJX (JAX backend).

MJX runs all environments in parallel on GPU via JAX, providing
~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).

Requirements:
    pip install 'jax[cuda12]'  # NVIDIA GPU
    pip install jax            # CPU fallback
    pip install mujoco-mjx     # provides ``mujoco.mjx``
"""
import dataclasses
import structlog
import torch
try:
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
except ImportError as e:
raise ImportError(
"MJX runner requires JAX and MuJoCo MJX. Install with:\n"
" pip install 'jax[cuda12]' # GPU\n"
" pip install jax # CPU\n"
) from e
import numpy as np
from src.core.env import BaseEnv
from src.core.runner import BaseRunner, BaseRunnerConfig
from src.runners.mujoco import MuJoCoRunner # reuse _load_model
log = structlog.get_logger()
@dataclasses.dataclass
class MJXRunnerConfig(BaseRunnerConfig):
    """Configuration for the MJX (JAX) batched MuJoCo runner.

    Raises:
        ValueError: if any field holds a value the runner cannot work with
            (non-positive counts or timestep, EMA weight outside [0, 1]).
    """

    # Number of environments simulated in parallel.
    num_envs: int = 1024
    # Torch device string used for tensors created by the runner.
    device: str = "cuda"
    # Physics timestep; the runner writes it to ``model.opt.timestep``.
    dt: float = 0.002
    # Number of physics steps executed per runner step.
    substeps: int = 20
    # Weight of the newest action in the control EMA
    # (``ctrl = alpha * action + (1 - alpha) * previous_ctrl``).
    action_ema_alpha: float = 0.2

    def __post_init__(self) -> None:
        """Validate field values early instead of failing deep inside JAX."""
        # Chain to the base config's hook if it defines one, so adding this
        # method does not hide validation done by the parent class.
        parent_hook = getattr(super(), "__post_init__", None)
        if parent_hook is not None:
            parent_hook()

        if self.num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {self.num_envs}")
        if self.substeps < 1:
            raise ValueError(f"substeps must be >= 1, got {self.substeps}")
        if not self.dt > 0.0:
            raise ValueError(f"dt must be > 0, got {self.dt}")
        if not 0.0 <= self.action_ema_alpha <= 1.0:
            raise ValueError(
                "action_ema_alpha must be within [0, 1], "
                f"got {self.action_ema_alpha}"
            )
class MJXRunner(BaseRunner[MJXRunnerConfig]):
    """GPU-batched MuJoCo runner using MJX (JAX).

    All environments are stepped together through JAX/MJX, and state is
    exchanged with PyTorch through DLPack.
    """

    # Half-width of the uniform noise added to the default qpos/qvel, both
    # when the batch is first created and on every reset.
    _RESET_NOISE = 0.05

    def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of parallel environments, taken from the config."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device built from ``config.device``."""
        return torch.device(self.config.device)

    # ── Initialization ───────────────────────────────────────────────
    def _sim_initialize(self, config: MJXRunnerConfig) -> None:
        """Build the MJX model, the batched state and the jitted functions.

        NOTE(review): presumably invoked by ``BaseRunner.__init__`` — confirm.
        """
        # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
        self._mj_model = MuJoCoRunner._load_model(self.env.robot)
        self._mj_model.opt.timestep = config.dt
        self._nq = self._mj_model.nq
        self._nv = self._mj_model.nv
        self._nu = self._mj_model.nu

        # Step 2: Put model on the JAX device
        self._mjx_model = mjx.put_model(self._mj_model)

        # Step 3: Default reset state, computed on CPU then uploaded
        default_data = mujoco.MjData(self._mj_model)
        default_qpos = self.env.get_default_qpos(self._nq)
        if default_qpos is not None:
            default_data.qpos[:] = default_qpos
        mujoco.mj_forward(self._mj_model, default_data)
        self._default_mjx_data = mjx.put_data(self._mj_model, default_data)

        # Step 4: Initialise all environments with small perturbations
        self._rng = jax.random.PRNGKey(42)
        self._batch_data = self._make_batched_data(config.num_envs)

        # Step 5: EMA ctrl state (JAX array, one row per environment)
        self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))

        # Step 6: Build the jitted hot-path functions
        self._compile_jit_fns(config.substeps)

        # Keep one CPU MjData for offscreen rendering
        self._render_data = mujoco.MjData(self._mj_model)

        log.info(
            "mjx_runner_ready",
            num_envs=config.num_envs,
            substeps=config.substeps,
            jax_devices=str(jax.devices()),
        )

    def _make_batched_data(self, n: int):
        """Create *n* environments with small random perturbations.

        Each environment starts from the default state with independent
        uniform noise in ``[-_RESET_NOISE, _RESET_NOISE]`` added to every
        qpos and qvel entry. Advances ``self._rng``.
        """
        self._rng, k1, k2 = jax.random.split(self._rng, 3)
        noise = self._RESET_NOISE
        pq = jax.random.uniform(k1, (n, self._nq), minval=-noise, maxval=noise)
        pv = jax.random.uniform(k2, (n, self._nv), minval=-noise, maxval=noise)

        default = self._default_mjx_data
        model = self._mjx_model

        def _init_one(pq_i, pv_i):
            d = default.replace(
                qpos=default.qpos + pq_i,
                qvel=default.qvel + pv_i,
            )
            # Recompute derived quantities for the perturbed state.
            return mjx.forward(model, d)

        return jax.vmap(_init_one)(pq, pv)

    def _compile_jit_fns(self, substeps: int) -> None:
        """Build the two jitted hot-path functions (step and reset).

        ``jax.jit`` is lazy: tracing and XLA compilation happen on the first
        call of each function, not here. Both closures capture the model and
        the default state, so they must be rebuilt if either changes.
        """
        model = self._mjx_model
        default = self._default_mjx_data
        noise = self._RESET_NOISE

        # ── Batched step (N substeps per call) ──────────────────────
        @jax.jit
        def step_fn(data):
            def body(_, d):
                return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)

            return jax.lax.fori_loop(0, substeps, body, data)

        self._jit_step = step_fn

        # ── Selective reset ─────────────────────────────────────────
        @jax.jit
        def reset_fn(data, smooth_ctrl, mask, rng):
            rng, k1, k2 = jax.random.split(rng, 3)
            ne = data.qpos.shape[0]
            # Noise is drawn for every environment (fixed shapes under jit);
            # only the rows selected by ``mask`` actually use it.
            pq = jax.random.uniform(
                k1, (ne, default.qpos.shape[0]), minval=-noise, maxval=noise,
            )
            pv = jax.random.uniform(
                k2, (ne, default.qvel.shape[0]), minval=-noise, maxval=noise,
            )
            m = mask[:, None]  # (num_envs, 1) broadcast helper
            new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
            new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
            new_ctrl = jnp.where(m, 0.0, data.ctrl)
            new_smooth = jnp.where(m, 0.0, smooth_ctrl)
            new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
            new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
            return new_data, new_smooth, rng

        self._jit_reset = reset_fn

    # ── Step / Reset ─────────────────────────────────────────────────
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply EMA-smoothed actions and advance every environment.

        Args:
            actions: per-environment controls; combined element-wise with the
                ``(num_envs, nu)`` EMA state.

        Returns:
            ``(qpos, qvel)`` of all environments as float32 torch tensors.
        """
        # PyTorch → JAX via DLPack
        actions_jax = jnp.from_dlpack(actions.detach().contiguous())

        # EMA smoothing (vectorised, no Python loop)
        alpha = self.config.action_ema_alpha
        self._smooth_ctrl = (
            alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
        )

        # Set ctrl & run N substeps for all environments
        self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
        self._batch_data = self._jit_step(self._batch_data)

        # JAX → PyTorch via DLPack, cast to float32
        qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
        qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
        return qpos, qvel

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected environments to a perturbed default state.

        Args:
            env_ids: tensor selecting the environments to reset.

        Returns:
            ``(qpos, qvel)`` of the reset environments only, as float32
            torch tensors. Empty selections yield tensors with zero rows.
        """
        if env_ids.numel() == 0:
            # Nothing to reset: skip the full-batch forward pass entirely.
            empty_qpos = torch.empty(
                (0, self._nq), dtype=torch.float32, device=self.device,
            )
            empty_qvel = torch.empty(
                (0, self._nv), dtype=torch.float32, device=self.device,
            )
            return empty_qpos, empty_qvel

        # Build a fixed-shape mask (fixed shape → no JIT recompilation).
        # It crosses DLPack as uint8 because bool tensors cannot be exported
        # through DLPack by older PyTorch releases; the cast back to bool
        # happens on the JAX side.
        mask = torch.zeros(
            self.config.num_envs, dtype=torch.uint8, device=self.device,
        )
        mask[env_ids] = 1
        mask_jax = jnp.from_dlpack(mask).astype(jnp.bool_)

        self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
            self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
        )

        # Return only the reset environments' states
        ids_np = env_ids.cpu().numpy()
        rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
        rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
        return torch.from_dlpack(rq), torch.from_dlpack(rv)

    # ── Rendering ────────────────────────────────────────────────────
    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen render — copies one env's state to the CPU MjData.

        Args:
            env_idx: environment to render; negative values count from the
                end, as with normal sequence indexing.

        Returns:
            The image produced by ``mujoco.Renderer.render()`` (640x480).

        Raises:
            IndexError: if ``env_idx`` is outside ``[-num_envs, num_envs)``.
        """
        n = self.config.num_envs
        # JAX clamps out-of-range indices instead of raising, which would
        # silently render the wrong environment; reject them explicitly.
        if not -n <= env_idx < n:
            raise IndexError(
                f"env_idx {env_idx} is out of range for {n} environments"
            )

        self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
        self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
        mujoco.mj_forward(self._mj_model, self._render_data)

        # The renderer is created lazily on the first render call.
        renderer = getattr(self, "_offscreen_renderer", None)
        if renderer is None:
            renderer = mujoco.Renderer(
                self._mj_model, width=640, height=480,
            )
            self._offscreen_renderer = renderer

        renderer.update_scene(self._render_data)
        return renderer.render()