feat: full DR (friction/damping/torque) in MJX JIT step
This commit is contained in:
@@ -7,11 +7,12 @@ history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
|
||||
rma_mode: "none" # "none" | "teacher" | "deploy"
|
||||
|
||||
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||
# NOTE: action-delay and sensor-noise are applied for MJX, but the
|
||||
# per-env dynamics *scales* (friction/damping/torque) are NOT yet wired
|
||||
# into the JIT step — use runner=mujoco for scale randomization, or keep
|
||||
# this block to delay+noise only on MJX.
|
||||
# Full DR on GPU: latency + sensor noise + per-env dynamics scales
|
||||
# (friction/damping/torque) are all applied inside the JIT step.
|
||||
domain_rand:
|
||||
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
|
||||
damping_scale: [0.6, 1.6] # viscous-damping multiplier (per env)
|
||||
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation (per env)
|
||||
|
||||
@@ -103,6 +103,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
# Keep one CPU MjData for offscreen rendering
|
||||
self._render_data = mujoco.MjData(self._mj_model)
|
||||
|
||||
# Per-env DR scale arrays (synced from torch on every reset).
|
||||
# Initialised to 1.0 (= DR off) here because _setup_domain_rand runs after this;
# the sampled values replace these placeholders on reset.
|
||||
self._mjx_fr = jnp.ones(config.num_envs)
|
||||
self._mjx_dp = jnp.ones(config.num_envs)
|
||||
self._mjx_tq = jnp.ones(config.num_envs)
|
||||
|
||||
log.info(
|
||||
"mjx_runner_ready",
|
||||
num_envs=config.num_envs,
|
||||
@@ -159,8 +165,15 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
||||
|
||||
# ── Batched step (N substeps per call) ──────────────────────
|
||||
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
|
||||
# multipliers (1.0 = off). Passed as args (not closure constants) so
|
||||
# resampling them every episode does NOT trigger JIT recompilation.
|
||||
@jax.jit
|
||||
def step_fn(data):
|
||||
def step_fn(data, fr_scale, dp_scale, tq_scale):
|
||||
fr = fr_scale[:, None] # broadcast over motor actuators
|
||||
dp = dp_scale[:, None]
|
||||
tq = tq_scale[:, None]
|
||||
|
||||
# Software limit switch: clamp ctrl once before substeps.
|
||||
pos = data.qpos[:, act_jnt_ids]
|
||||
ctrl = data.ctrl
|
||||
@@ -175,6 +188,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
||||
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
||||
mc = mc * gear_dir / _gear_avg
|
||||
mc = mc * tq # torque_scale (DR)
|
||||
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
||||
|
||||
data = data.replace(ctrl=ctrl)
|
||||
@@ -184,13 +198,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
vel = d.qvel[:, _qvel_ids]
|
||||
mc = d.ctrl[:, _ctrl_ids]
|
||||
|
||||
# Coulomb friction (direction-dependent)
|
||||
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
|
||||
# Coulomb friction (direction-dependent) × DR scale
|
||||
fl = jnp.where(vel > 0, _fl_pos, _fl_neg) * fr
|
||||
torque = -jnp.where(
|
||||
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
||||
)
|
||||
# Viscous damping (direction-dependent)
|
||||
damp = jnp.where(vel > 0, _damp_pos, _damp_neg)
|
||||
# Viscous damping (direction-dependent) × DR scale
|
||||
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
|
||||
torque = torque - damp * vel
|
||||
# Quadratic velocity drag
|
||||
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
||||
@@ -242,9 +256,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||
|
||||
# Set ctrl & run N substeps for all environments
|
||||
# Set ctrl & run N substeps for all environments (with per-env DR scales)
|
||||
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||
self._batch_data = self._jit_step(self._batch_data)
|
||||
self._batch_data = self._jit_step(
|
||||
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
|
||||
)
|
||||
|
||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||
@@ -263,6 +279,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
self._batch_data, mask_jax, self._rng,
|
||||
)
|
||||
|
||||
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
|
||||
# resamples self._dr_scales just before this call, so re-deriving the
|
||||
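# full arrays here keeps the JAX copies current for every env.
# NOTE(review): from_dlpack may share memory with the torch tensors — confirm
# BaseRunner does not mutate self._dr_scales in place while JAX still holds
# these arrays, or copy before converting.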
|
||||
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
|
||||
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
|
||||
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
|
||||
|
||||
# Return only the reset environments' states
|
||||
ids_np = env_ids.cpu().numpy()
|
||||
rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
|
||||
|
||||
Reference in New Issue
Block a user