From 0f13086fee8c746c7842a851a3f4efc35c577be2 Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Mon, 9 Mar 2026 22:47:57 +0100 Subject: [PATCH] :sparkles: remove custom ema and use mujoco motor control --- assets/rotary_cartpole/robot.yaml | 5 +-- configs/runner/mjx.yaml | 1 - configs/runner/mujoco.yaml | 1 - src/core/env.py | 1 - src/core/robot.py | 1 + src/runners/mjx.py | 23 ++++--------- src/runners/mujoco.py | 56 +++++++++++++++++++------------ 7 files changed, 45 insertions(+), 43 deletions(-) diff --git a/assets/rotary_cartpole/robot.yaml b/assets/rotary_cartpole/robot.yaml index 8228c19..25b35d7 100644 --- a/assets/rotary_cartpole/robot.yaml +++ b/assets/rotary_cartpole/robot.yaml @@ -6,9 +6,10 @@ urdf: rotary_cartpole.urdf actuators: - joint: motor_joint type: motor # direct torque control - gear: 0.5 # torque multiplier + gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang) ctrl_range: [-1.0, 1.0] - damping: 0.1 # motor friction / back-EMF + damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget) + filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia) joints: pendulum_joint: diff --git a/configs/runner/mjx.yaml b/configs/runner/mjx.yaml index 891f353..87ccd3f 100644 --- a/configs/runner/mjx.yaml +++ b/configs/runner/mjx.yaml @@ -2,4 +2,3 @@ num_envs: 1024 # MJX shines with many parallel envs device: auto # auto = cuda if available, else cpu dt: 0.002 substeps: 20 -action_ema_alpha: 0.2 diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml index 418f2bd..7d23949 100644 --- a/configs/runner/mujoco.yaml +++ b/configs/runner/mujoco.yaml @@ -2,4 +2,3 @@ num_envs: 64 device: auto # auto = cuda if available, else cpu dt: 0.002 substeps: 20 -action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse) diff --git a/src/core/env.py b/src/core/env.py index 5ec6e48..75d9e81 100644 --- a/src/core/env.py +++ b/src/core/env.py @@ -3,7 +3,6 @@ import dataclasses from typing import TypeVar, Generic, Any from gymnasium import spaces import torch -import pathlib from src.core.robot import RobotConfig, load_robot_config diff --git a/src/core/robot.py b/src/core/robot.py index 73f31e3..1e96cb9 100644 --- a/src/core/robot.py +++ b/src/core/robot.py @@ -35,6 +35,7 @@ class ActuatorConfig: damping: float = 0.05 kp: float = 0.0 # proportional gain (position / velocity actuators) kv: float = 0.0 # derivative gain (position actuators) + filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter @dataclasses.dataclass diff --git a/src/runners/mjx.py b/src/runners/mjx.py index 33c9ab9..11b3ca9 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -40,7 +40,6 @@ class MJXRunnerConfig(BaseRunnerConfig): device: str = "cuda" dt: float = 0.002 substeps: int = 20 - action_ema_alpha: float = 0.2 class MJXRunner(BaseRunner[MJXRunnerConfig]): @@ -86,10 +85,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): self._rng = jax.random.PRNGKey(42) self._batch_data = self._make_batched_data(config.num_envs) - # Step 5: EMA ctrl state (on GPU as JAX array) - self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu)) - - # Step 6: JIT-compile the hot-path functions + # Step 5: JIT-compile the hot-path functions self._compile_jit_fns(config.substeps) # Keep one CPU MjData for offscreen rendering @@ -152,7 +148,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # ── Selective reset ───────────────────────────────────────── @jax.jit - def reset_fn(data, smooth_ctrl, mask, rng): + def reset_fn(data, mask, rng): rng, k1, k2 = jax.random.split(rng, 3) ne = data.qpos.shape[0] @@ -168,12 +164,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): new_qpos = jnp.where(m, default.qpos + pq, data.qpos) new_qvel = jnp.where(m, default.qvel + pv, data.qvel) new_ctrl = jnp.where(m, 0.0, data.ctrl) - new_smooth = jnp.where(m, 0.0, smooth_ctrl) new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl) new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data) - return new_data, new_smooth, rng + return new_data, rng self._jit_reset = reset_fn @@ -183,14 +178,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): # PyTorch → JAX (zero-copy on GPU via DLPack) actions_jax = jnp.from_dlpack(actions.detach().contiguous()) - # EMA smoothing (vectorised on GPU, no Python loop) - alpha = self.config.action_ema_alpha - self._smooth_ctrl = ( - alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl - ) - # Set ctrl & run N substeps for all environments - self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl) + self._batch_data = self._batch_data.replace(ctrl=actions_jax) self._batch_data = self._jit_step(self._batch_data) # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32) @@ -206,8 +195,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]): mask[env_ids] = True mask_jax = jnp.from_dlpack(mask) - self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset( - self._batch_data, self._smooth_ctrl, mask_jax, self._rng, + self._batch_data, self._rng = self._jit_reset( + self._batch_data, mask_jax, self._rng, ) # Return only the reset environments' states diff --git a/src/runners/mujoco.py b/src/runners/mujoco.py index bd3f7ff..a555c3b 100644 --- a/src/runners/mujoco.py +++ b/src/runners/mujoco.py @@ -16,7 +16,6 @@ class MuJoCoRunnerConfig(BaseRunnerConfig): device: str = "cpu" dt: float = 0.02 substeps: int = 2 - action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant) @dataclasses.dataclass @@ -110,16 +109,39 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): "gear": str(act.gear), "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", } - if act.type == "position": - attribs["kp"] = str(act.kp) - if act.kv > 0: - attribs["kv"] = str(act.kv) - ET.SubElement(act_elem, "position", attrib=attribs) - elif act.type == "velocity": - attribs["kp"] = str(act.kp) - ET.SubElement(act_elem, "velocity", attrib=attribs) - else: # motor (default) - ET.SubElement(act_elem, "motor", attrib=attribs) + + # Actuator dynamics: 1st-order low-pass filter on ctrl. + # MuJoCo applies this every physics substep, which is + # more physically accurate than an external EMA. + # dyntype is only available on , not on + # shortcut elements like //. + use_general = act.filter_tau > 0 + + if use_general: + attribs["dyntype"] = "filter" + attribs["dynprm"] = str(act.filter_tau) + + if use_general: + attribs["gaintype"] = "fixed" + if act.type == "position": + attribs["biastype"] = "affine" + attribs["gainprm"] = str(act.kp) + attribs["biasprm"] = f"0 -{act.kp} -{act.kv}" + elif act.type == "velocity": + attribs["biastype"] = "affine" + attribs["gainprm"] = str(act.kp) + attribs["biasprm"] = f"0 0 -{act.kp}" + else: # motor + attribs["biastype"] = "none" + ET.SubElement(act_elem, "general", attrib=attribs) + else: + if act.type == "position": + attribs["kp"] = str(act.kp) + if act.kv > 0: + attribs["kv"] = str(act.kv) + elif act.type == "velocity": + attribs["kp"] = str(act.kp) + ET.SubElement(act_elem, act.type, attrib=attribs) # ── Apply joint overrides from robot.yaml ─────────────── # Merge actuator damping + explicit joint overrides @@ -183,24 +205,16 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): self._nq = self._model.nq self._nv = self._model.nv - # Per-env smoothed ctrl state for EMA action filtering. - # Models real motor inertia: ctrl can't reverse instantly. - nu = self._model.nu - self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)] - self._limits = self._extract_actuator_limits(self._model) def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: actions_np: np.ndarray = actions.cpu().numpy() - alpha = self.config.action_ema_alpha qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32) qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32) for i, data in enumerate(self._data): - # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl - self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i] - data.ctrl[:] = self._smooth_ctrl[i] + data.ctrl[:] = actions_np[i] for _ in range(self.config.substeps): # Software limit switch: zero out ctrl that would push # a joint past its position limit (like a real controller). @@ -245,7 +259,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]): data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv) # Reset smoothed ctrl so motor starts from rest - self._smooth_ctrl[env_id][:] = 0.0 + data.ctrl[:] = 0.0 qpos_batch[i] = data.qpos qvel_batch[i] = data.qvel