diff --git a/src/runners/mjx.py b/src/runners/mjx.py index 3511fd2..33c9ab9 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -125,10 +125,25 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): model = self._mjx_model default = self._default_mjx_data + # ── Per-actuator limit info (shared helper on MuJoCoRunner) ── + lim = MuJoCoRunner._extract_actuator_limits(self._mj_model) + act_jnt_ids = jnp.array(lim.jnt_ids) + act_limited = jnp.array(lim.limited) + act_lo = jnp.array(lim.lo) + act_hi = jnp.array(lim.hi) + act_gs = jnp.array(lim.gear_sign) + # ── Batched step (N substeps per call) ────────────────────── @jax.jit def step_fn(data): def body(_, d): + # Software limit switch: zero ctrl pushing past joint limits + pos = d.qpos[:, act_jnt_ids] # (num_envs, nu) + ctrl = d.ctrl + at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0) + at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0) + clamped = jnp.where(at_hi | at_lo, 0.0, ctrl) + d = d.replace(ctrl=clamped) return jax.vmap(mjx.step, in_axes=(None, 0))(model, d) return jax.lax.fori_loop(0, substeps, body, data) diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index c86b391..bd3f7ff 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -18,6 +18,20 @@ class MuJoCoRunnerConfig(BaseRunnerConfig): substeps: int = 2 action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant) + +@dataclasses.dataclass +class ActuatorLimits: + """Per-actuator joint-limit info used for software ctrl clamping. + + Real motor controllers have position limits that prevent the motor + from driving into the mechanical hard stop. + """ + jnt_ids: np.ndarray # (nu,) joint index for each actuator + limited: np.ndarray # (nu,) bool — whether that joint is limited + lo: np.ndarray # (nu,) lower position bound + hi: np.ndarray # (nu,) upper position bound + gear_sign: np.ndarray # (nu,) sign of gear ratio + class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig): super().__init__(env, config) @@ -134,8 +148,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): for body in root.iter("body"): for jnt in body.findall("joint"): if jnt.get("limited") == "true" or jnt.get("range"): - jnt.set("solreflimit", "-1000 -100") - jnt.set("solimplimit", "0.95 0.99 0.001") + jnt.set("solreflimit", "-100000 -10000") + jnt.set("solimplimit", "0.99 0.999 0.001") # Step 4: Write modified MJCF and reload from file path # (from_xml_path resolves mesh paths relative to the file location) @@ -145,6 +159,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): finally: tmp_mjcf.unlink(missing_ok=True) + @staticmethod + def _extract_actuator_limits(model: mujoco.MjModel) -> ActuatorLimits: + """Extract per-actuator joint-limit arrays from a loaded MjModel. + + Shared by MuJoCoRunner and MJXRunner so the logic isn't duplicated. + """ + nu = model.nu + jnt_ids = np.array([model.actuator_trnid[i, 0] for i in range(nu)]) + return ActuatorLimits( + jnt_ids=jnt_ids, + limited=np.array([model.jnt_limited[j] for j in jnt_ids], dtype=bool), + lo=np.array([model.jnt_range[j, 0] for j in jnt_ids]), + hi=np.array([model.jnt_range[j, 1] for j in jnt_ids]), + gear_sign=np.sign(np.array([model.actuator_gear[i, 0] for i in range(nu)])), + ) + def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None: self._model = self._load_model(self.env.robot) self._model.opt.timestep = config.dt @@ -158,6 +188,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): nu = self._model.nu self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)] + self._limits = self._extract_actuator_limits(self._model) + def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: actions_np: np.ndarray = actions.cpu().numpy() alpha = self.config.action_ema_alpha @@ -170,6 +202,17 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i] data.ctrl[:] = self._smooth_ctrl[i] for _ in range(self.config.substeps): + # Software limit switch: zero out ctrl that would push + # a joint past its position limit (like a real controller). + lim = self._limits + for a in range(self._model.nu): + if lim.limited[a]: + pos = data.qpos[lim.jnt_ids[a]] + gs = lim.gear_sign[a] + if pos >= lim.hi[a] and gs * data.ctrl[a] > 0: + data.ctrl[a] = 0.0 + elif pos <= lim.lo[a] and gs * data.ctrl[a] < 0: + data.ctrl[a] = 0.0 mujoco.mj_step(self._model, data) qpos_batch[i] = data.qpos