add limit enforce to mujoco for joints

This commit is contained in:
2026-03-09 22:30:48 +01:00
parent 70cd2cdd7d
commit 9813319275
2 changed files with 60 additions and 2 deletions

View File

@@ -125,10 +125,25 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
model = self._mjx_model
default = self._default_mjx_data
# ── Per-actuator limit info (shared helper on MuJoCoRunner) ──
lim = MuJoCoRunner._extract_actuator_limits(self._mj_model)
act_jnt_ids = jnp.array(lim.jnt_ids)
act_limited = jnp.array(lim.limited)
act_lo = jnp.array(lim.lo)
act_hi = jnp.array(lim.hi)
act_gs = jnp.array(lim.gear_sign)
# ── Batched step (N substeps per call) ──────────────────────
@jax.jit
def step_fn(data):
def body(_, d):
# Software limit switch: zero ctrl pushing past joint limits
pos = d.qpos[:, act_jnt_ids] # (num_envs, nu)
ctrl = d.ctrl
at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0)
at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0)
clamped = jnp.where(at_hi | at_lo, 0.0, ctrl)
d = d.replace(ctrl=clamped)
return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
return jax.lax.fori_loop(0, substeps, body, data)

View File

@@ -18,6 +18,20 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
substeps: int = 2
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
@dataclasses.dataclass
class ActuatorLimits:
"""Per-actuator joint-limit info used for software ctrl clamping.
Real motor controllers have position limits that prevent the motor
from driving into the mechanical hard stop.
"""
jnt_ids: np.ndarray # (nu,) joint index for each actuator
limited: np.ndarray # (nu,) bool — whether that joint is limited
lo: np.ndarray # (nu,) lower position bound
hi: np.ndarray # (nu,) upper position bound
gear_sign: np.ndarray # (nu,) sign of gear ratio
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
super().__init__(env, config)
@@ -134,8 +148,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("limited") == "true" or jnt.get("range"):
jnt.set("solreflimit", "-1000 -100")
jnt.set("solimplimit", "0.95 0.99 0.001")
jnt.set("solreflimit", "-100000 -10000")
jnt.set("solimplimit", "0.99 0.999 0.001")
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
@@ -145,6 +159,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
finally:
tmp_mjcf.unlink(missing_ok=True)
@staticmethod
def _extract_actuator_limits(model: mujoco.MjModel) -> ActuatorLimits:
"""Extract per-actuator joint-limit arrays from a loaded MjModel.
Shared by MuJoCoRunner and MJXRunner so the logic isn't duplicated.
"""
nu = model.nu
jnt_ids = np.array([model.actuator_trnid[i, 0] for i in range(nu)])
return ActuatorLimits(
jnt_ids=jnt_ids,
limited=np.array([model.jnt_limited[j] for j in jnt_ids], dtype=bool),
lo=np.array([model.jnt_range[j, 0] for j in jnt_ids]),
hi=np.array([model.jnt_range[j, 1] for j in jnt_ids]),
gear_sign=np.sign(np.array([model.actuator_gear[i, 0] for i in range(nu)])),
)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
self._model = self._load_model(self.env.robot)
self._model.opt.timestep = config.dt
@@ -158,6 +188,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
nu = self._model.nu
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
self._limits = self._extract_actuator_limits(self._model)
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
alpha = self.config.action_ema_alpha
@@ -170,6 +202,17 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
data.ctrl[:] = self._smooth_ctrl[i]
for _ in range(self.config.substeps):
# Software limit switch: zero out ctrl that would push
# a joint past its position limit (like a real controller).
lim = self._limits
for a in range(self._model.nu):
if lim.limited[a]:
pos = data.qpos[lim.jnt_ids[a]]
gs = lim.gear_sign[a]
if pos >= lim.hi[a] and gs * data.ctrl[a] > 0:
data.ctrl[a] = 0.0
elif pos <= lim.lo[a] and gs * data.ctrl[a] < 0:
data.ctrl[a] = 0.0
mujoco.mj_step(self._model, data)
qpos_batch[i] = data.qpos