diff --git a/configs/runner/mjx.yaml b/configs/runner/mjx.yaml index f695009..87f225c 100644 --- a/configs/runner/mjx.yaml +++ b/configs/runner/mjx.yaml @@ -7,11 +7,12 @@ history_length: 10 # RMA-style: 10-step window of (obs, action) pairs rma_mode: "none" # "none" | "teacher" | "deploy" # ── Domain randomization (sim-to-real) ────────────────────────────── -# NOTE: action-delay and sensor-noise are applied for MJX, but the -# per-env dynamics *scales* (friction/damping/torque) are NOT yet wired -# into the JIT step — use runner=mujoco for scale randomization, or keep -# this block to delay+noise only on MJX. +# Full DR on GPU: latency + sensor noise + per-env dynamics scales +# (friction/damping/torque) are all applied inside the JIT step. domain_rand: qpos_noise_std: 0.01 # rad — encoder angle noise qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured) action_delay_steps: [0, 2] # control-step latency (0–40 ms) + friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env) + damping_scale: [0.6, 1.6] # viscous-damping multiplier + torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation diff --git a/src/runners/mjx.py b/src/runners/mjx.py index b67ace9..ed02079 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -103,6 +103,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # Keep one CPU MjData for offscreen rendering self._render_data = mujoco.MjData(self._mj_model) + # Per-env DR scale arrays (synced from torch on every reset). + # Initialised to 1.0 here because _setup_domain_rand runs after this. + self._mjx_fr = jnp.ones(config.num_envs) + self._mjx_dp = jnp.ones(config.num_envs) + self._mjx_tq = jnp.ones(config.num_envs) + log.info( "mjx_runner_ready", num_envs=config.num_envs, @@ -159,8 +165,15 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): _back_emf = jnp.array([a.back_emf_gain for a in acts]) # ── Batched step (N substeps per call) ────────────────────── + # fr/dp/tq_scale are per-env (num_envs,) domain-randomization + # multipliers (1.0 = off). Passed as args (not closure constants) so + # resampling them every episode does NOT trigger JIT recompilation. @jax.jit - def step_fn(data): + def step_fn(data, fr_scale, dp_scale, tq_scale): + fr = fr_scale[:, None] # broadcast over motor actuators + dp = dp_scale[:, None] + tq = tq_scale[:, None] + # Software limit switch: clamp ctrl once before substeps. pos = data.qpos[:, act_jnt_ids] ctrl = data.ctrl @@ -175,6 +188,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc) gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg) mc = mc * gear_dir / _gear_avg + mc = mc * tq # torque_scale (DR) ctrl = ctrl.at[:, _ctrl_ids].set(mc) data = data.replace(ctrl=ctrl) @@ -184,13 +198,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): vel = d.qvel[:, _qvel_ids] mc = d.ctrl[:, _ctrl_ids] - # Coulomb friction (direction-dependent) - fl = jnp.where(vel > 0, _fl_pos, _fl_neg) + # Coulomb friction (direction-dependent) × DR scale + fl = jnp.where(vel > 0, _fl_pos, _fl_neg) * fr torque = -jnp.where( jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0, ) - # Viscous damping (direction-dependent) - damp = jnp.where(vel > 0, _damp_pos, _damp_neg) + # Viscous damping (direction-dependent) × DR scale + damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp torque = torque - damp * vel # Quadratic velocity drag torque = torque - _visc_quad * vel * jnp.abs(vel) @@ -242,9 +256,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # PyTorch → JAX (zero-copy on GPU via DLPack) actions_jax = jnp.from_dlpack(actions.detach().contiguous()) - # Set ctrl & run N substeps for all environments + # Set ctrl & run N substeps for all environments (with per-env DR scales) self._batch_data = self._batch_data.replace(ctrl=actions_jax) - self._batch_data = self._jit_step(self._batch_data) + self._batch_data = self._jit_step( + self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq, + ) # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32) qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32)) @@ -263,6 +279,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): self._batch_data, mask_jax, self._rng, ) + # Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner + # resamples self._dr_scales just before this call, so re-deriving the + # full arrays here keeps the JAX copies current for every env. + self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous()) + self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous()) + self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous()) + # Return only the reset environments' states ids_np = env_ids.cpu().numpy() rq = self._batch_data.qpos[ids_np].astype(jnp.float32)