✨ add limit enforce to mujoco for joints
This commit is contained in:
@@ -125,10 +125,25 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
model = self._mjx_model
|
||||
default = self._default_mjx_data
|
||||
|
||||
# ── Per-actuator limit info (shared helper on MuJoCoRunner) ──
|
||||
lim = MuJoCoRunner._extract_actuator_limits(self._mj_model)
|
||||
act_jnt_ids = jnp.array(lim.jnt_ids)
|
||||
act_limited = jnp.array(lim.limited)
|
||||
act_lo = jnp.array(lim.lo)
|
||||
act_hi = jnp.array(lim.hi)
|
||||
act_gs = jnp.array(lim.gear_sign)
|
||||
|
||||
# ── Batched step (N substeps per call) ──────────────────────
|
||||
@jax.jit
|
||||
def step_fn(data):
|
||||
def body(_, d):
|
||||
# Software limit switch: zero ctrl pushing past joint limits
|
||||
pos = d.qpos[:, act_jnt_ids] # (num_envs, nu)
|
||||
ctrl = d.ctrl
|
||||
at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0)
|
||||
at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0)
|
||||
clamped = jnp.where(at_hi | at_lo, 0.0, ctrl)
|
||||
d = d.replace(ctrl=clamped)
|
||||
return jax.vmap(mjx.step, in_axes=(None, 0))(model, d)
|
||||
|
||||
return jax.lax.fori_loop(0, substeps, body, data)
|
||||
|
||||
@@ -18,6 +18,20 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
substeps: int = 2
|
||||
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ActuatorLimits:
|
||||
"""Per-actuator joint-limit info used for software ctrl clamping.
|
||||
|
||||
Real motor controllers have position limits that prevent the motor
|
||||
from driving into the mechanical hard stop.
|
||||
"""
|
||||
jnt_ids: np.ndarray # (nu,) joint index for each actuator
|
||||
limited: np.ndarray # (nu,) bool — whether that joint is limited
|
||||
lo: np.ndarray # (nu,) lower position bound
|
||||
hi: np.ndarray # (nu,) upper position bound
|
||||
gear_sign: np.ndarray # (nu,) sign of gear ratio
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
|
||||
super().__init__(env, config)
|
||||
@@ -134,8 +148,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("limited") == "true" or jnt.get("range"):
|
||||
jnt.set("solreflimit", "-1000 -100")
|
||||
jnt.set("solimplimit", "0.95 0.99 0.001")
|
||||
jnt.set("solreflimit", "-100000 -10000")
|
||||
jnt.set("solimplimit", "0.99 0.999 0.001")
|
||||
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
@@ -145,6 +159,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
finally:
|
||||
tmp_mjcf.unlink(missing_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def _extract_actuator_limits(model: mujoco.MjModel) -> ActuatorLimits:
|
||||
"""Extract per-actuator joint-limit arrays from a loaded MjModel.
|
||||
|
||||
Shared by MuJoCoRunner and MJXRunner so the logic isn't duplicated.
|
||||
"""
|
||||
nu = model.nu
|
||||
jnt_ids = np.array([model.actuator_trnid[i, 0] for i in range(nu)])
|
||||
return ActuatorLimits(
|
||||
jnt_ids=jnt_ids,
|
||||
limited=np.array([model.jnt_limited[j] for j in jnt_ids], dtype=bool),
|
||||
lo=np.array([model.jnt_range[j, 0] for j in jnt_ids]),
|
||||
hi=np.array([model.jnt_range[j, 1] for j in jnt_ids]),
|
||||
gear_sign=np.sign(np.array([model.actuator_gear[i, 0] for i in range(nu)])),
|
||||
)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
self._model = self._load_model(self.env.robot)
|
||||
self._model.opt.timestep = config.dt
|
||||
@@ -158,6 +188,8 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
nu = self._model.nu
|
||||
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
||||
|
||||
self._limits = self._extract_actuator_limits(self._model)
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
actions_np: np.ndarray = actions.cpu().numpy()
|
||||
alpha = self.config.action_ema_alpha
|
||||
@@ -170,6 +202,17 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
||||
data.ctrl[:] = self._smooth_ctrl[i]
|
||||
for _ in range(self.config.substeps):
|
||||
# Software limit switch: zero out ctrl that would push
|
||||
# a joint past its position limit (like a real controller).
|
||||
lim = self._limits
|
||||
for a in range(self._model.nu):
|
||||
if lim.limited[a]:
|
||||
pos = data.qpos[lim.jnt_ids[a]]
|
||||
gs = lim.gear_sign[a]
|
||||
if pos >= lim.hi[a] and gs * data.ctrl[a] > 0:
|
||||
data.ctrl[a] = 0.0
|
||||
elif pos <= lim.lo[a] and gs * data.ctrl[a] < 0:
|
||||
data.ctrl[a] = 0.0
|
||||
mujoco.mj_step(self._model, data)
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
|
||||
Reference in New Issue
Block a user