feat: full DR (friction/damping/torque) in MJX JIT step

This commit is contained in:
2026-06-09 21:25:05 +02:00
parent b37cd26690
commit 56499ebe97
2 changed files with 35 additions and 11 deletions

View File

@@ -7,11 +7,12 @@ history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
rma_mode: "none" # "none" | "teacher" | "deploy" rma_mode: "none" # "none" | "teacher" | "deploy"
# ── Domain randomization (sim-to-real) ────────────────────────────── # ── Domain randomization (sim-to-real) ──────────────────────────────
# NOTE: action-delay and sensor-noise are applied for MJX, but the # Full DR on GPU: latency + sensor noise + per-env dynamics scales
# per-env dynamics *scales* (friction/damping/torque) are NOT yet wired # (friction/damping/torque) are all applied inside the JIT step.
# into the JIT step — use runner=mujoco for scale randomization, or keep
# this block to delay+noise only on MJX.
domain_rand: domain_rand:
qpos_noise_std: 0.01 # rad — encoder angle noise qpos_noise_std: 0.01 # rad — encoder angle noise
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured) qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
action_delay_steps: [0, 2] # control-step latency (0-40 ms) action_delay_steps: [0, 2] # control-step latency (0-40 ms)
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
damping_scale: [0.6, 1.6] # viscous-damping multiplier
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation

View File

@@ -103,6 +103,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# Keep one CPU MjData for offscreen rendering # Keep one CPU MjData for offscreen rendering
self._render_data = mujoco.MjData(self._mj_model) self._render_data = mujoco.MjData(self._mj_model)
# Per-env DR scale arrays (synced from torch on every reset).
# Initialised to 1.0 here because _setup_domain_rand runs after this.
self._mjx_fr = jnp.ones(config.num_envs)
self._mjx_dp = jnp.ones(config.num_envs)
self._mjx_tq = jnp.ones(config.num_envs)
log.info( log.info(
"mjx_runner_ready", "mjx_runner_ready",
num_envs=config.num_envs, num_envs=config.num_envs,
@@ -159,8 +165,15 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
_back_emf = jnp.array([a.back_emf_gain for a in acts]) _back_emf = jnp.array([a.back_emf_gain for a in acts])
# ── Batched step (N substeps per call) ────────────────────── # ── Batched step (N substeps per call) ──────────────────────
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
# multipliers (1.0 = off). Passed as args (not closure constants) so
# resampling them every episode does NOT trigger JIT recompilation.
@jax.jit @jax.jit
def step_fn(data): def step_fn(data, fr_scale, dp_scale, tq_scale):
fr = fr_scale[:, None] # broadcast over motor actuators
dp = dp_scale[:, None]
tq = tq_scale[:, None]
# Software limit switch: clamp ctrl once before substeps. # Software limit switch: clamp ctrl once before substeps.
pos = data.qpos[:, act_jnt_ids] pos = data.qpos[:, act_jnt_ids]
ctrl = data.ctrl ctrl = data.ctrl
@@ -175,6 +188,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc) mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg) gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
mc = mc * gear_dir / _gear_avg mc = mc * gear_dir / _gear_avg
mc = mc * tq # torque_scale (DR)
ctrl = ctrl.at[:, _ctrl_ids].set(mc) ctrl = ctrl.at[:, _ctrl_ids].set(mc)
data = data.replace(ctrl=ctrl) data = data.replace(ctrl=ctrl)
@@ -184,13 +198,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
vel = d.qvel[:, _qvel_ids] vel = d.qvel[:, _qvel_ids]
mc = d.ctrl[:, _ctrl_ids] mc = d.ctrl[:, _ctrl_ids]
# Coulomb friction (direction-dependent) # Coulomb friction (direction-dependent) × DR scale
fl = jnp.where(vel > 0, _fl_pos, _fl_neg) fl = jnp.where(vel > 0, _fl_pos, _fl_neg) * fr
torque = -jnp.where( torque = -jnp.where(
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0, jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
) )
# Viscous damping (direction-dependent) # Viscous damping (direction-dependent) × DR scale
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
torque = torque - damp * vel torque = torque - damp * vel
# Quadratic velocity drag # Quadratic velocity drag
torque = torque - _visc_quad * vel * jnp.abs(vel) torque = torque - _visc_quad * vel * jnp.abs(vel)
@@ -242,9 +256,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# PyTorch → JAX (zero-copy on GPU via DLPack) # PyTorch → JAX (zero-copy on GPU via DLPack)
actions_jax = jnp.from_dlpack(actions.detach().contiguous()) actions_jax = jnp.from_dlpack(actions.detach().contiguous())
# Set ctrl & run N substeps for all environments # Set ctrl & run N substeps for all environments (with per-env DR scales)
self._batch_data = self._batch_data.replace(ctrl=actions_jax) self._batch_data = self._batch_data.replace(ctrl=actions_jax)
self._batch_data = self._jit_step(self._batch_data) self._batch_data = self._jit_step(
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
)
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32) # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32)) qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
@@ -263,6 +279,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._batch_data, mask_jax, self._rng, self._batch_data, mask_jax, self._rng,
) )
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
# resamples self._dr_scales just before this call, so re-deriving the
# full arrays here keeps the JAX copies current for every env.
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
# Return only the reset environments' states # Return only the reset environments' states
ids_np = env_ids.cpu().numpy() ids_np = env_ids.cpu().numpy()
rq = self._batch_data.qpos[ids_np].astype(jnp.float32) rq = self._batch_data.qpos[ids_np].astype(jnp.float32)