diff --git a/assets/rotary_cartpole/robot.yaml b/assets/rotary_cartpole/robot.yaml index 25b35d7..d18686a 100644 --- a/assets/rotary_cartpole/robot.yaml +++ b/assets/rotary_cartpole/robot.yaml @@ -6,11 +6,14 @@ urdf: rotary_cartpole.urdf actuators: - joint: motor_joint type: motor # direct torque control - gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang) + gear: 0.064 # stall torque @ 58.8% PWM: 0.108 × 150/255 = 0.064 N·m ctrl_range: [-1.0, 1.0] - damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget) - filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia) + damping: 0.003 # viscous back-EMF only (small) + filter_tau: 0.03 # mechanical time constant ~30ms (37mm gearmotor) joints: + motor_joint: + armature: 0.0001 # reflected rotor inertia: ~1e-7 × 30² = 9e-5 kg·m² + frictionloss: 0.03 # disabled — may slow MJX (constraint-based solver) pendulum_joint: damping: 0.0001 # bearing friction diff --git a/src/core/robot.py b/src/core/robot.py index 1e96cb9..e8ef93a 100644 --- a/src/core/robot.py +++ b/src/core/robot.py @@ -42,6 +42,8 @@ class ActuatorConfig: class JointConfig: """Per-joint overrides applied on top of the URDF values.""" damping: float | None = None + armature: float | None = None # reflected rotor inertia (kg·m²) + frictionloss: float | None = None # Coulomb/dry friction torque (N·m) @dataclasses.dataclass diff --git a/src/runners/mjx.py b/src/runners/mjx.py index 11b3ca9..bf2dcce 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -132,14 +132,17 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # ── Batched step (N substeps per call) ────────────────────── @jax.jit def step_fn(data): + # Software limit switch: clamp ctrl once before substeps. + # Armature + default joint limits prevent significant overshoot + # within the substep window, so per-substep clamping isn't needed. + pos = data.qpos[:, act_jnt_ids] + ctrl = data.ctrl + at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0) + at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0) + clamped = jnp.where(at_hi | at_lo, 0.0, ctrl) + data = data.replace(ctrl=clamped) + def body(_, d): - # Software limit switch: zero ctrl pushing past joint limits - pos = d.qpos[:, act_jnt_ids] # (num_envs, nu) - ctrl = d.ctrl - at_hi = act_limited & (pos >= act_hi) & (act_gs * ctrl > 0) - at_lo = act_limited & (pos <= act_lo) & (act_gs * ctrl < 0) - clamped = jnp.where(at_hi | at_lo, 0.0, ctrl) - d = d.replace(ctrl=clamped) return jax.vmap(mjx.step, in_axes=(None, 0))(model, d) return jax.lax.fori_loop(0, substeps, body, data) diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index a555c3b..bc395c0 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -146,15 +146,25 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): # ── Apply joint overrides from robot.yaml ─────────────── # Merge actuator damping + explicit joint overrides joint_damping = {a.joint: a.damping for a in robot.actuators} + joint_armature: dict[str, float] = {} + joint_frictionloss: dict[str, float] = {} for name, jcfg in robot.joints.items(): if jcfg.damping is not None: joint_damping[name] = jcfg.damping + if jcfg.armature is not None: + joint_armature[name] = jcfg.armature + if jcfg.frictionloss is not None: + joint_frictionloss[name] = jcfg.frictionloss for body in root.iter("body"): for jnt in body.findall("joint"): name = jnt.get("name") if name in joint_damping: jnt.set("damping", str(joint_damping[name])) + if name in joint_armature: + jnt.set("armature", str(joint_armature[name])) + if name in joint_frictionloss: + jnt.set("frictionloss", str(joint_frictionloss[name])) # Disable self-collision on all geoms. # URDF mesh convex hulls often overlap at joints (especially @@ -164,14 +174,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): geom.set("contype", "0") geom.set("conaffinity", "0") - # Harden joint limits: MuJoCo's default soft limits are too - # weak and allow overshoot. Negative solref = hard constraint - # (direct stiffness/damping instead of impedance match). - for body in root.iter("body"): - for jnt in body.findall("joint"): - if jnt.get("limited") == "true" or jnt.get("range"): - jnt.set("solreflimit", "-100000 -10000") - jnt.set("solimplimit", "0.99 0.999 0.001") + # Joint limits use MuJoCo's default solver settings. + # The software limit switch (zeroing ctrl at limits) + armature + # prevent overshoot without needing ultra-stiff solref that + # kills MJX GPU solver performance. # Step 4: Write modified MJCF and reload from file path # (from_xml_path resolves mesh paths relative to the file location)