"""GPU-batched MuJoCo simulation using MJX (JAX backend). MJX runs all environments in parallel on GPU via JAX, providing ~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+). Requirements: pip install 'jax[cuda12]' # NVIDIA GPU pip install jax # CPU fallback """ import dataclasses import structlog import torch try: import jax import jax.numpy as jnp import mujoco from mujoco import mjx except ImportError as e: raise ImportError( "MJX runner requires JAX and MuJoCo MJX. Install with:\n" " pip install 'jax[cuda12]' # GPU\n" " pip install jax # CPU\n" ) from e import numpy as np from src.core.env import BaseEnv from src.core.runner import BaseRunner, BaseRunnerConfig from src.runners.mujoco import MuJoCoRunner # reuse _load_model log = structlog.get_logger() @dataclasses.dataclass class MJXRunnerConfig(BaseRunnerConfig): num_envs: int = 1024 device: str = "cuda" dt: float = 0.002 substeps: int = 20 action_ema_alpha: float = 0.2 class MJXRunner(BaseRunner[MJXRunnerConfig]): """GPU-batched MuJoCo runner using MJX (JAX). Physics runs entirely on GPU via JAX; observations flow to PyTorch through zero-copy DLPack transfers. """ def __init__(self, env: BaseEnv, config: MJXRunnerConfig): super().__init__(env, config) @property def num_envs(self) -> int: return self.config.num_envs @property def device(self) -> torch.device: return torch.device(self.config.device) # ── Initialization ─────────────────────────────────────────────── def _sim_initialize(self, config: MJXRunnerConfig) -> None: # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection) self._mj_model = MuJoCoRunner._load_model(self.env.robot) self._mj_model.opt.timestep = config.dt self._nq = self._mj_model.nq self._nv = self._mj_model.nv self._nu = self._mj_model.nu # Step 2: Put model on GPU self._mjx_model = mjx.put_model(self._mj_model) # Step 3: Default reset state on GPU default_data = mujoco.MjData(self._mj_model) default_qpos = self.env.get_default_qpos(self._nq) if default_qpos is not None: default_data.qpos[:] = default_qpos mujoco.mj_forward(self._mj_model, default_data) self._default_mjx_data = mjx.put_data(self._mj_model, default_data) # Step 4: Initialise all environments with small perturbations self._rng = jax.random.PRNGKey(42) self._batch_data = self._make_batched_data(config.num_envs) # Step 5: EMA ctrl state (on GPU as JAX array) self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu)) # Step 6: JIT-compile the hot-path functions self._compile_jit_fns(config.substeps) # Keep one CPU MjData for offscreen rendering self._render_data = mujoco.MjData(self._mj_model) log.info( "mjx_runner_ready", num_envs=config.num_envs, substeps=config.substeps, jax_devices=str(jax.devices()), ) def _make_batched_data(self, n: int): """Create *n* environments with small random perturbations.""" self._rng, k1, k2 = jax.random.split(self._rng, 3) pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05) pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05) default = self._default_mjx_data model = self._mjx_model def _init_one(pq_i, pv_i): d = default.replace( qpos=default.qpos + pq_i, qvel=default.qvel + pv_i, ) return mjx.forward(model, d) return jax.vmap(_init_one)(pq, pv) def _compile_jit_fns(self, substeps: int) -> None: """Pre-compile the two hot-path functions so the first call is fast.""" model = self._mjx_model default = self._default_mjx_data # ── Batched step (N substeps per call) ────────────────────── @jax.jit def step_fn(data): def body(_, d): return jax.vmap(mjx.step, in_axes=(None, 0))(model, d) return jax.lax.fori_loop(0, substeps, body, data) self._jit_step = step_fn # ── Selective reset ───────────────────────────────────────── @jax.jit def reset_fn(data, smooth_ctrl, mask, rng): rng, k1, k2 = jax.random.split(rng, 3) ne = data.qpos.shape[0] pq = jax.random.uniform( k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05, ) pv = jax.random.uniform( k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05, ) m = mask[:, None] # (num_envs, 1) broadcast helper new_qpos = jnp.where(m, default.qpos + pq, data.qpos) new_qvel = jnp.where(m, default.qvel + pv, data.qvel) new_ctrl = jnp.where(m, 0.0, data.ctrl) new_smooth = jnp.where(m, 0.0, smooth_ctrl) new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl) new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data) return new_data, new_smooth, rng self._jit_reset = reset_fn # ── Step / Reset ───────────────────────────────────────────────── def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # PyTorch → JAX (zero-copy on GPU via DLPack) actions_jax = jnp.from_dlpack(actions.detach().contiguous()) # EMA smoothing (vectorised on GPU, no Python loop) alpha = self.config.action_ema_alpha self._smooth_ctrl = ( alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl ) # Set ctrl & run N substeps for all environments self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl) self._batch_data = self._jit_step(self._batch_data) # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32) qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32)) qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32)) return qpos, qvel def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Build boolean mask (fixed shape → no JIT recompilation) mask = torch.zeros( self.config.num_envs, dtype=torch.bool, device=self.device, ) mask[env_ids] = True mask_jax = jnp.from_dlpack(mask) self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset( self._batch_data, self._smooth_ctrl, mask_jax, self._rng, ) # Return only the reset environments' states ids_np = env_ids.cpu().numpy() rq = self._batch_data.qpos[ids_np].astype(jnp.float32) rv = self._batch_data.qvel[ids_np].astype(jnp.float32) return torch.from_dlpack(rq), torch.from_dlpack(rv) # ── Rendering ──────────────────────────────────────────────────── def render(self, env_idx: int = 0) -> np.ndarray: """Offscreen render — copies one env's state from GPU to CPU.""" self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx]) self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx]) mujoco.mj_forward(self._mj_model, self._render_data) if not hasattr(self, "_offscreen_renderer"): self._offscreen_renderer = mujoco.Renderer( self._mj_model, width=640, height=480, ) self._offscreen_renderer.update_scene(self._render_data) return self._offscreen_renderer.render()