Remove custom EMA and use MuJoCo motor control

This commit is contained in:
2026-03-09 22:47:57 +01:00
parent 9813319275
commit 0f13086fee
7 changed files with 45 additions and 43 deletions

View File

@@ -6,9 +6,10 @@ urdf: rotary_cartpole.urdf
actuators: actuators:
- joint: motor_joint - joint: motor_joint
type: motor # direct torque control type: motor # direct torque control
gear: 0.5 # torque multiplier gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang)
ctrl_range: [-1.0, 1.0] ctrl_range: [-1.0, 1.0]
damping: 0.1 # motor friction / back-EMF damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget)
filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia)
joints: joints:
pendulum_joint: pendulum_joint:

View File

@@ -2,4 +2,3 @@ num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 20 substeps: 20
action_ema_alpha: 0.2

View File

@@ -2,4 +2,3 @@ num_envs: 64
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 20 substeps: 20
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)

View File

@@ -3,7 +3,6 @@ import dataclasses
from typing import TypeVar, Generic, Any from typing import TypeVar, Generic, Any
from gymnasium import spaces from gymnasium import spaces
import torch import torch
import pathlib
from src.core.robot import RobotConfig, load_robot_config from src.core.robot import RobotConfig, load_robot_config

View File

@@ -35,6 +35,7 @@ class ActuatorConfig:
damping: float = 0.05 damping: float = 0.05
kp: float = 0.0 # proportional gain (position / velocity actuators) kp: float = 0.0 # proportional gain (position / velocity actuators)
kv: float = 0.0 # derivative gain (position actuators) kv: float = 0.0 # derivative gain (position actuators)
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
@dataclasses.dataclass @dataclasses.dataclass

View File

@@ -40,7 +40,6 @@ class MJXRunnerConfig(BaseRunnerConfig):
device: str = "cuda" device: str = "cuda"
dt: float = 0.002 dt: float = 0.002
substeps: int = 20 substeps: int = 20
action_ema_alpha: float = 0.2
class MJXRunner(BaseRunner[MJXRunnerConfig]): class MJXRunner(BaseRunner[MJXRunnerConfig]):
@@ -86,10 +85,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._rng = jax.random.PRNGKey(42) self._rng = jax.random.PRNGKey(42)
self._batch_data = self._make_batched_data(config.num_envs) self._batch_data = self._make_batched_data(config.num_envs)
# Step 5: EMA ctrl state (on GPU as JAX array) # Step 5: JIT-compile the hot-path functions
self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
# Step 6: JIT-compile the hot-path functions
self._compile_jit_fns(config.substeps) self._compile_jit_fns(config.substeps)
# Keep one CPU MjData for offscreen rendering # Keep one CPU MjData for offscreen rendering
@@ -152,7 +148,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# ── Selective reset ───────────────────────────────────────── # ── Selective reset ─────────────────────────────────────────
@jax.jit @jax.jit
def reset_fn(data, smooth_ctrl, mask, rng): def reset_fn(data, mask, rng):
rng, k1, k2 = jax.random.split(rng, 3) rng, k1, k2 = jax.random.split(rng, 3)
ne = data.qpos.shape[0] ne = data.qpos.shape[0]
@@ -168,12 +164,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
new_qpos = jnp.where(m, default.qpos + pq, data.qpos) new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
new_qvel = jnp.where(m, default.qvel + pv, data.qvel) new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
new_ctrl = jnp.where(m, 0.0, data.ctrl) new_ctrl = jnp.where(m, 0.0, data.ctrl)
new_smooth = jnp.where(m, 0.0, smooth_ctrl)
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl) new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data) new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
return new_data, new_smooth, rng return new_data, rng
self._jit_reset = reset_fn self._jit_reset = reset_fn
@@ -183,14 +178,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# PyTorch → JAX (zero-copy on GPU via DLPack) # PyTorch → JAX (zero-copy on GPU via DLPack)
actions_jax = jnp.from_dlpack(actions.detach().contiguous()) actions_jax = jnp.from_dlpack(actions.detach().contiguous())
# EMA smoothing (vectorised on GPU, no Python loop)
alpha = self.config.action_ema_alpha
self._smooth_ctrl = (
alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
)
# Set ctrl & run N substeps for all environments # Set ctrl & run N substeps for all environments
self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl) self._batch_data = self._batch_data.replace(ctrl=actions_jax)
self._batch_data = self._jit_step(self._batch_data) self._batch_data = self._jit_step(self._batch_data)
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32) # JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
@@ -206,8 +195,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
mask[env_ids] = True mask[env_ids] = True
mask_jax = jnp.from_dlpack(mask) mask_jax = jnp.from_dlpack(mask)
self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset( self._batch_data, self._rng = self._jit_reset(
self._batch_data, self._smooth_ctrl, mask_jax, self._rng, self._batch_data, mask_jax, self._rng,
) )
# Return only the reset environments' states # Return only the reset environments' states

View File

@@ -16,7 +16,6 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
device: str = "cpu" device: str = "cpu"
dt: float = 0.02 dt: float = 0.02
substeps: int = 2 substeps: int = 2
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
@dataclasses.dataclass @dataclasses.dataclass
@@ -110,16 +109,39 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
"gear": str(act.gear), "gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}", "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
} }
if act.type == "position":
attribs["kp"] = str(act.kp) # Actuator dynamics: 1st-order low-pass filter on ctrl.
if act.kv > 0: # MuJoCo applies this every physics substep, which is
attribs["kv"] = str(act.kv) # more physically accurate than an external EMA.
ET.SubElement(act_elem, "position", attrib=attribs) # dyntype is only available on <general>, not on
elif act.type == "velocity": # shortcut elements like <motor>/<position>/<velocity>.
attribs["kp"] = str(act.kp) use_general = act.filter_tau > 0
ET.SubElement(act_elem, "velocity", attrib=attribs)
else: # motor (default) if use_general:
ET.SubElement(act_elem, "motor", attrib=attribs) attribs["dyntype"] = "filter"
attribs["dynprm"] = str(act.filter_tau)
if use_general:
attribs["gaintype"] = "fixed"
if act.type == "position":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
elif act.type == "velocity":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 0 -{act.kp}"
else: # motor
attribs["biastype"] = "none"
ET.SubElement(act_elem, "general", attrib=attribs)
else:
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, act.type, attrib=attribs)
# ── Apply joint overrides from robot.yaml ─────────────── # ── Apply joint overrides from robot.yaml ───────────────
# Merge actuator damping + explicit joint overrides # Merge actuator damping + explicit joint overrides
@@ -183,24 +205,16 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._nq = self._model.nq self._nq = self._model.nq
self._nv = self._model.nv self._nv = self._model.nv
# Per-env smoothed ctrl state for EMA action filtering.
# Models real motor inertia: ctrl can't reverse instantly.
nu = self._model.nu
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
self._limits = self._extract_actuator_limits(self._model) self._limits = self._extract_actuator_limits(self._model)
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy() actions_np: np.ndarray = actions.cpu().numpy()
alpha = self.config.action_ema_alpha
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32) qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32) qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data): for i, data in enumerate(self._data):
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl data.ctrl[:] = actions_np[i]
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
data.ctrl[:] = self._smooth_ctrl[i]
for _ in range(self.config.substeps): for _ in range(self.config.substeps):
# Software limit switch: zero out ctrl that would push # Software limit switch: zero out ctrl that would push
# a joint past its position limit (like a real controller). # a joint past its position limit (like a real controller).
@@ -245,7 +259,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv) data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
# Reset smoothed ctrl so motor starts from rest # Reset ctrl so motor starts from rest
self._smooth_ctrl[env_id][:] = 0.0 data.ctrl[:] = 0.0
qpos_batch[i] = data.qpos qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel qvel_batch[i] = data.qvel