feat: full DR (friction/damping/torque) in MJX JIT step
This commit is contained in:
@@ -7,11 +7,12 @@ history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
|
|||||||
rma_mode: "none" # "none" | "teacher" | "deploy"
|
rma_mode: "none" # "none" | "teacher" | "deploy"
|
||||||
|
|
||||||
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
# ── Domain randomization (sim-to-real) ──────────────────────────────
|
||||||
# NOTE: action-delay and sensor-noise are applied for MJX, but the
|
# Full DR on GPU: latency + sensor noise + per-env dynamics scales
|
||||||
# per-env dynamics *scales* (friction/damping/torque) are NOT yet wired
|
# (friction/damping/torque) are all applied inside the JIT step.
|
||||||
# into the JIT step — use runner=mujoco for scale randomization, or keep
|
|
||||||
# this block to delay+noise only on MJX.
|
|
||||||
domain_rand:
|
domain_rand:
|
||||||
qpos_noise_std: 0.01 # rad — encoder angle noise
|
qpos_noise_std: 0.01 # rad — encoder angle noise
|
||||||
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
qvel_noise_std: 0.5 # rad/s — velocity-estimate noise (measured)
|
||||||
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
action_delay_steps: [0, 2] # control-step latency (0–40 ms)
|
||||||
|
friction_scale: [0.6, 1.6] # Coulomb-friction multiplier (per env)
|
||||||
|
damping_scale: [0.6, 1.6] # viscous-damping multiplier
|
||||||
|
torque_scale: [0.85, 1.15] # motor-constant / battery-voltage variation
|
||||||
|
|||||||
@@ -103,6 +103,12 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
# Keep one CPU MjData for offscreen rendering
|
# Keep one CPU MjData for offscreen rendering
|
||||||
self._render_data = mujoco.MjData(self._mj_model)
|
self._render_data = mujoco.MjData(self._mj_model)
|
||||||
|
|
||||||
|
# Per-env DR scale arrays (synced from torch on every reset).
|
||||||
|
# Initialised to 1.0 (no scaling) here because _setup_domain_rand runs after this.
|
||||||
|
self._mjx_fr = jnp.ones(config.num_envs)
|
||||||
|
self._mjx_dp = jnp.ones(config.num_envs)
|
||||||
|
self._mjx_tq = jnp.ones(config.num_envs)
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"mjx_runner_ready",
|
"mjx_runner_ready",
|
||||||
num_envs=config.num_envs,
|
num_envs=config.num_envs,
|
||||||
@@ -159,8 +165,15 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
_back_emf = jnp.array([a.back_emf_gain for a in acts])
|
||||||
|
|
||||||
# ── Batched step (N substeps per call) ──────────────────────
|
# ── Batched step (N substeps per call) ──────────────────────
|
||||||
|
# fr/dp/tq_scale are per-env (num_envs,) domain-randomization
|
||||||
|
# multipliers (1.0 = off). Passed as args (not closure constants) so
|
||||||
|
# freshly resampled values take effect every episode without re-tracing or recompiling the step.
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def step_fn(data):
|
def step_fn(data, fr_scale, dp_scale, tq_scale):
|
||||||
|
fr = fr_scale[:, None] # broadcast over motor actuators
|
||||||
|
dp = dp_scale[:, None]
|
||||||
|
tq = tq_scale[:, None]
|
||||||
|
|
||||||
# Software limit switch: clamp ctrl once before substeps.
|
# Software limit switch: clamp ctrl once before substeps.
|
||||||
pos = data.qpos[:, act_jnt_ids]
|
pos = data.qpos[:, act_jnt_ids]
|
||||||
ctrl = data.ctrl
|
ctrl = data.ctrl
|
||||||
@@ -175,6 +188,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
mc = jnp.where((mc < 0) & (mc > -_dz_neg), 0.0, mc)
|
||||||
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
gear_dir = jnp.where(mc >= 0, _gear_pos, _gear_neg)
|
||||||
mc = mc * gear_dir / _gear_avg
|
mc = mc * gear_dir / _gear_avg
|
||||||
|
mc = mc * tq # torque_scale (DR)
|
||||||
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
ctrl = ctrl.at[:, _ctrl_ids].set(mc)
|
||||||
|
|
||||||
data = data.replace(ctrl=ctrl)
|
data = data.replace(ctrl=ctrl)
|
||||||
@@ -184,13 +198,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
vel = d.qvel[:, _qvel_ids]
|
vel = d.qvel[:, _qvel_ids]
|
||||||
mc = d.ctrl[:, _ctrl_ids]
|
mc = d.ctrl[:, _ctrl_ids]
|
||||||
|
|
||||||
# Coulomb friction (direction-dependent)
|
# Coulomb friction (direction-dependent) × DR scale
|
||||||
fl = jnp.where(vel > 0, _fl_pos, _fl_neg)
|
fl = jnp.where(vel > 0, _fl_pos, _fl_neg) * fr
|
||||||
torque = -jnp.where(
|
torque = -jnp.where(
|
||||||
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
jnp.abs(vel) > 1e-6, jnp.sign(vel) * fl, 0.0,
|
||||||
)
|
)
|
||||||
# Viscous damping (direction-dependent)
|
# Viscous damping (direction-dependent) × DR scale
|
||||||
damp = jnp.where(vel > 0, _damp_pos, _damp_neg)
|
damp = jnp.where(vel > 0, _damp_pos, _damp_neg) * dp
|
||||||
torque = torque - damp * vel
|
torque = torque - damp * vel
|
||||||
# Quadratic velocity drag
|
# Quadratic velocity drag
|
||||||
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
torque = torque - _visc_quad * vel * jnp.abs(vel)
|
||||||
@@ -242,9 +256,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||||
|
|
||||||
# Set ctrl & run N substeps for all environments
|
# Set ctrl & run N substeps for all environments (with per-env DR scales)
|
||||||
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||||
self._batch_data = self._jit_step(self._batch_data)
|
self._batch_data = self._jit_step(
|
||||||
|
self._batch_data, self._mjx_fr, self._mjx_dp, self._mjx_tq,
|
||||||
|
)
|
||||||
|
|
||||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||||
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
qpos = torch.from_dlpack(self._batch_data.qpos.astype(jnp.float32))
|
||||||
@@ -263,6 +279,13 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
self._batch_data, mask_jax, self._rng,
|
self._batch_data, mask_jax, self._rng,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Sync per-env DR scales (torch → JAX) for the step fn. BaseRunner
|
||||||
|
# resamples self._dr_scales just before this call, so re-deriving the
|
||||||
|
# full arrays here keeps the JAX copies current for every env.
|
||||||
|
self._mjx_fr = jnp.from_dlpack(self._dr_scales["friction_scale"].contiguous())
|
||||||
|
self._mjx_dp = jnp.from_dlpack(self._dr_scales["damping_scale"].contiguous())
|
||||||
|
self._mjx_tq = jnp.from_dlpack(self._dr_scales["torque_scale"].contiguous())
|
||||||
|
|
||||||
# Return only the reset environments' states
|
# Return only the reset environments' states
|
||||||
ids_np = env_ids.cpu().numpy()
|
ids_np = env_ids.cpu().numpy()
|
||||||
rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
|
rq = self._batch_data.qpos[ids_np].astype(jnp.float32)
|
||||||
|
|||||||
Reference in New Issue
Block a user