✨ add mjx runner
This commit is contained in:
5
configs/runner/mjx.yaml
Normal file
5
configs/runner/mjx.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Configuration for the GPU-batched MJX runner.
num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu
dt: 0.002 # physics timestep in seconds
substeps: 20 # physics substeps per control step
action_ema_alpha: 0.2 # EMA weight of the newest action (1.0 = no smoothing)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
num_envs: 64
|
num_envs: 64
|
||||||
device: cpu
|
device: auto # auto = cuda if available, else cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 20
|
substeps: 20
|
||||||
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ gymnasium
|
|||||||
hydra-core
|
hydra-core
|
||||||
omegaconf
|
omegaconf
|
||||||
mujoco
|
mujoco
|
||||||
|
mujoco-mjx
|
||||||
|
jax
|
||||||
skrl[torch]
|
skrl[torch]
|
||||||
clearml
|
clearml
|
||||||
imageio
|
imageio
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ class BaseRunner(abc.ABC, Generic[T]):
|
|||||||
self.env = env
|
self.env = env
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# Resolve "auto" device before anything uses it
|
||||||
|
if getattr(self.config, "device", None) == "auto":
|
||||||
|
self.config.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
self._sim_initialize(config)
|
self._sim_initialize(config)
|
||||||
|
|
||||||
self.observation_space = self.env.observation_space
|
self.observation_space = self.env.observation_space
|
||||||
|
|||||||
227
src/runners/mjx.py
Normal file
227
src/runners/mjx.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
"""GPU-batched MuJoCo simulation using MJX (JAX backend).
|
||||||
|
|
||||||
|
MJX runs all environments in parallel on GPU via JAX, providing
|
||||||
|
~10-100x speedup over the CPU MuJoCoRunner for large env counts (1024+).
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
pip install 'jax[cuda12]' # NVIDIA GPU
|
||||||
|
pip install jax # CPU fallback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import mujoco
|
||||||
|
from mujoco import mjx
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"MJX runner requires JAX and MuJoCo MJX. Install with:\n"
|
||||||
|
" pip install 'jax[cuda12]' # GPU\n"
|
||||||
|
" pip install jax # CPU\n"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.core.env import BaseEnv
|
||||||
|
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||||
|
from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class MJXRunnerConfig(BaseRunnerConfig):
    """Configuration for the GPU-batched MJX (JAX) simulation runner."""

    # Number of environments simulated in parallel on the accelerator.
    num_envs: int = 1024
    # Torch device string; "auto" is resolved to cuda/cpu by BaseRunner.__init__.
    device: str = "cuda"
    # Physics timestep in seconds (applied to the MuJoCo model on init).
    dt: float = 0.002
    # Physics substeps executed per control step.
    substeps: int = 20
    # EMA weight of the newest action in the ctrl smoothing filter
    # (1.0 = no smoothing, smaller = heavier smoothing).
    action_ema_alpha: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
class MJXRunner(BaseRunner[MJXRunnerConfig]):
    """GPU-batched MuJoCo runner using MJX (JAX).

    Physics runs entirely on GPU via JAX; observations flow to
    PyTorch through zero-copy DLPack transfers.
    """

    def __init__(self, env: BaseEnv, config: MJXRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of environments simulated in parallel."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device mirroring the configured simulation device."""
        return torch.device(self.config.device)

    # ── Initialization ───────────────────────────────────────────────

    def _sim_initialize(self, config: MJXRunnerConfig) -> None:
        """Load the model on CPU, mirror it to GPU, and JIT the hot paths.

        Raises:
            ValueError: if the environment config has no ``model_path``.
        """
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified")

        # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
        self._mj_model = MuJoCoRunner._load_model_with_actuators(
            str(model_path), self.env.config.actuators,
        )
        self._mj_model.opt.timestep = config.dt
        self._nq = self._mj_model.nq
        self._nv = self._mj_model.nv
        self._nu = self._mj_model.nu

        # Step 2: Put model on GPU
        self._mjx_model = mjx.put_model(self._mj_model)

        # Step 3: Default reset state on GPU
        default_data = mujoco.MjData(self._mj_model)
        default_qpos = self.env.get_default_qpos(self._nq)
        if default_qpos is not None:
            default_data.qpos[:] = default_qpos
        mujoco.mj_forward(self._mj_model, default_data)
        self._default_mjx_data = mjx.put_data(self._mj_model, default_data)

        # Step 4: Initialise all environments with small perturbations.
        # Fixed seed → reproducible initial states across runs.
        self._rng = jax.random.PRNGKey(42)
        self._batch_data = self._make_batched_data(config.num_envs)

        # Step 5: EMA ctrl state (on GPU as JAX array)
        self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))

        # Step 6: JIT-compile the hot-path functions
        self._compile_jit_fns(config.substeps)

        # Keep one CPU MjData for offscreen rendering; the renderer itself is
        # created lazily on the first render() call and cleared by _sim_close()
        # (mirrors MuJoCoRunner's renderer lifecycle).
        self._render_data = mujoco.MjData(self._mj_model)
        self._offscreen_renderer = None

        log.info(
            "mjx_runner_ready",
            num_envs=config.num_envs,
            substeps=config.substeps,
            jax_devices=str(jax.devices()),
        )

    def _make_batched_data(self, n: int):
        """Create *n* environments with small random perturbations."""
        self._rng, k1, k2 = jax.random.split(self._rng, 3)
        pq = jax.random.uniform(k1, (n, self._nq), minval=-0.05, maxval=0.05)
        pv = jax.random.uniform(k2, (n, self._nv), minval=-0.05, maxval=0.05)

        default = self._default_mjx_data
        model = self._mjx_model

        def _init_one(pq_i, pv_i):
            # Perturb the default keyframe and recompute derived quantities.
            d = default.replace(
                qpos=default.qpos + pq_i,
                qvel=default.qvel + pv_i,
            )
            return mjx.forward(model, d)

        return jax.vmap(_init_one)(pq, pv)

    def _compile_jit_fns(self, substeps: int) -> None:
        """Pre-compile the two hot-path functions so the first call is fast."""
        model = self._mjx_model
        default = self._default_mjx_data

        # ── Batched step (N substeps per call) ──────────────────────
        @jax.jit
        def step_fn(data):
            def body(_, d):
                return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)

            return jax.lax.fori_loop(0, substeps, body, data)

        self._jit_step = step_fn

        # ── Selective reset ─────────────────────────────────────────
        @jax.jit
        def reset_fn(data, smooth_ctrl, mask, rng):
            rng, k1, k2 = jax.random.split(rng, 3)
            ne = data.qpos.shape[0]

            pq = jax.random.uniform(
                k1, (ne, default.qpos.shape[0]), minval=-0.05, maxval=0.05,
            )
            pv = jax.random.uniform(
                k2, (ne, default.qvel.shape[0]), minval=-0.05, maxval=0.05,
            )

            m = mask[:, None]  # (num_envs, 1) broadcast helper

            # Only masked environments get fresh perturbed state; the rest
            # keep their current state unchanged.
            new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
            new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
            new_ctrl = jnp.where(m, 0.0, data.ctrl)
            new_smooth = jnp.where(m, 0.0, smooth_ctrl)

            new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
            new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)

            return new_data, new_smooth, rng

        self._jit_reset = reset_fn

    # ── Step / Reset ─────────────────────────────────────────────────

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply smoothed actions and advance all envs by ``substeps``.

        Returns:
            ``(qpos, qvel)`` batched float32 torch tensors.
        """
        # PyTorch → JAX (zero-copy on GPU via DLPack)
        # NOTE(review): assumes `actions` lives on the same device as the JAX
        # arrays — confirm callers honor config.device.
        actions_jax = jnp.from_dlpack(actions.detach().contiguous())

        # EMA smoothing (vectorised on GPU, no Python loop)
        alpha = self.config.action_ema_alpha
        self._smooth_ctrl = (
            alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
        )

        # Set ctrl & run N substeps for all environments
        self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
        self._batch_data = self._jit_step(self._batch_data)

        # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
        qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
        qvel = torch.from_dlpack(self._batch_data.qvel.astype(jnp.float32))
        return qpos, qvel

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the environments listed in ``env_ids``; return their state."""
        # Build boolean mask (fixed shape → no JIT recompilation)
        mask = torch.zeros(
            self.config.num_envs, dtype=torch.bool, device=self.device,
        )
        mask[env_ids] = True
        mask_jax = jnp.from_dlpack(mask)

        self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
            self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
        )

        # Return only the reset environments' states
        ids_np = env_ids.cpu().numpy()
        rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
        rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
        return torch.from_dlpack(rq), torch.from_dlpack(rv)

    def _sim_close(self) -> None:
        """Release the offscreen renderer. Safe to call more than once."""
        if getattr(self, "_offscreen_renderer", None) is not None:
            self._offscreen_renderer.close()
            # Clear the handle so render() can lazily recreate it and a second
            # close is a no-op (consistent with MuJoCoRunner._sim_close).
            self._offscreen_renderer = None

    # ── Rendering ────────────────────────────────────────────────────

    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen render — copies one env's state from GPU to CPU.

        Args:
            env_idx: index of the environment to visualise.
        """
        self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
        self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
        mujoco.mj_forward(self._mj_model, self._render_data)

        # Lazily (re-)create the renderer; _sim_close() sets it back to None.
        if getattr(self, "_offscreen_renderer", None) is None:
            self._offscreen_renderer = mujoco.Renderer(
                self._mj_model, width=640, height=480,
            )
        self._offscreen_renderer.update_scene(self._render_data)
        return self._offscreen_renderer.render()
|
||||||
@@ -198,12 +198,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
def _sim_close(self) -> None:
|
def _sim_close(self) -> None:
|
||||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||||
self._offscreen_renderer.close()
|
self._offscreen_renderer.close()
|
||||||
self._offscreen_renderer = None
|
|
||||||
self._data.clear()
|
|
||||||
|
|
||||||
def render(self, env_idx: int = 0) -> np.ndarray | None:
|
def render(self, env_idx: int = 0) -> np.ndarray:
|
||||||
"""Offscreen render → RGB numpy array (H, W, 3)."""
|
"""Offscreen render of a single environment."""
|
||||||
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
|
if not hasattr(self, "_offscreen_renderer"):
|
||||||
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
|
self._offscreen_renderer = mujoco.Renderer(
|
||||||
|
self._model, width=640, height=480,
|
||||||
|
)
|
||||||
|
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||||
return self._offscreen_renderer.render().copy()
|
return self._offscreen_renderer.render()
|
||||||
@@ -4,6 +4,7 @@ import tempfile
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import structlog
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from clearml import Logger
|
from clearml import Logger
|
||||||
@@ -16,6 +17,8 @@ from skrl.trainers.torch import SequentialTrainer
|
|||||||
from src.core.runner import BaseRunner
|
from src.core.runner import BaseRunner
|
||||||
from src.models.mlp import SharedMLP
|
from src.models.mlp import SharedMLP
|
||||||
|
|
||||||
|
log = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class TrainerConfig:
|
class TrainerConfig:
|
||||||
@@ -125,7 +128,14 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||||
|
|
||||||
|
try:
|
||||||
frame = self.env.render()
|
frame = self.env.render()
|
||||||
|
except Exception:
|
||||||
|
# Headless environment without OpenGL — skip video recording
|
||||||
|
log.warning("video_recording_disabled", reason="render failed (headless?)")
|
||||||
|
self.env.reset()
|
||||||
|
return
|
||||||
|
|
||||||
if frame is not None:
|
if frame is not None:
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
|
|
||||||
|
|||||||
48
train.py
48
train.py
@@ -1,4 +1,10 @@
|
|||||||
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||||
|
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||||
|
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import hydra.utils as hydra_utils
|
import hydra.utils as hydra_utils
|
||||||
@@ -8,9 +14,9 @@ from hydra.core.hydra_config import HydraConfig
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||||
|
from src.core.runner import BaseRunner
|
||||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
||||||
from src.training.trainer import Trainer, TrainerConfig
|
from src.training.trainer import Trainer, TrainerConfig
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -41,6 +47,33 @@ def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
|||||||
return env_cls(config_cls(**env_dict))
|
return env_cls(config_cls(**env_dict))
|
||||||
|
|
||||||
|
|
||||||
|
# ── runner registry ───────────────────────────────────────────────────
|
||||||
|
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
|
||||||
|
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
|
||||||
|
|
||||||
|
# ── runner registry ───────────────────────────────────────────────────
# Hydra config-group name → (module path, runner class, config class).
# Modules are imported lazily so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
}


def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Instantiate the right runner from the Hydra config-group name."""
    try:
        module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
    except KeyError:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        ) from None

    # Deferred import keeps heavy backends out of the default startup path.
    import importlib

    runner_module = importlib.import_module(module_path)
    runner_cls = getattr(runner_module, cls_name)
    config_cls = getattr(runner_module, cfg_cls_name)

    resolved_cfg = OmegaConf.to_container(cfg.runner, resolve=True)
    return runner_cls(env=env, config=config_cls(**resolved_cfg))
|
||||||
|
|
||||||
|
|
||||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||||
"""Initialize ClearML task with project structure and tags.
|
"""Initialize ClearML task with project structure and tags.
|
||||||
|
|
||||||
@@ -58,7 +91,14 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
|||||||
tags = [env_name, runner_name, training_name]
|
tags = [env_name, runner_name, training_name]
|
||||||
|
|
||||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||||
task.set_base_docker("registry.kube.optimize/worker-image:latest")
|
task.set_base_docker(
|
||||||
|
"registry.kube.optimize/worker-image:latest",
|
||||||
|
docker_setup_bash_script=(
|
||||||
|
"apt-get update && apt-get install -y --no-install-recommends "
|
||||||
|
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||||
|
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||||
task.set_packages(str(req_file))
|
task.set_packages(str(req_file))
|
||||||
@@ -84,10 +124,8 @@ def main(cfg: DictConfig) -> None:
|
|||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "cartpole")
|
||||||
env = _build_env(env_name, cfg)
|
env = _build_env(env_name, cfg)
|
||||||
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||||
trainer_config = TrainerConfig(**training_dict)
|
trainer_config = TrainerConfig(**training_dict)
|
||||||
|
|
||||||
runner = MuJoCoRunner(env=env, config=runner_config)
|
|
||||||
trainer = Trainer(runner=runner, config=trainer_config)
|
trainer = Trainer(runner=runner, config=trainer_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user