remove custom ema and use mujoco motor control

This commit is contained in:
2026-03-09 22:47:57 +01:00
parent 9813319275
commit 0f13086fee
7 changed files with 45 additions and 43 deletions

View File

@@ -6,9 +6,10 @@ urdf: rotary_cartpole.urdf
actuators:
- joint: motor_joint
type: motor # direct torque control
gear: 0.5 # torque multiplier
gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang)
ctrl_range: [-1.0, 1.0]
damping: 0.1 # motor friction / back-EMF
damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget)
filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia)
joints:
pendulum_joint:

View File

@@ -2,4 +2,3 @@ num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu
dt: 0.002
substeps: 20
action_ema_alpha: 0.2

View File

@@ -2,4 +2,3 @@ num_envs: 64
device: auto # auto = cuda if available, else cpu
dt: 0.002
substeps: 20
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)

View File

@@ -3,7 +3,6 @@ import dataclasses
from typing import TypeVar, Generic, Any
from gymnasium import spaces
import torch
import pathlib
from src.core.robot import RobotConfig, load_robot_config

View File

@@ -35,6 +35,7 @@ class ActuatorConfig:
damping: float = 0.05
kp: float = 0.0 # proportional gain (position / velocity actuators)
kv: float = 0.0 # derivative gain (position actuators)
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
@dataclasses.dataclass

View File

@@ -40,7 +40,6 @@ class MJXRunnerConfig(BaseRunnerConfig):
device: str = "cuda"
dt: float = 0.002
substeps: int = 20
action_ema_alpha: float = 0.2
class MJXRunner(BaseRunner[MJXRunnerConfig]):
@@ -86,10 +85,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
self._rng = jax.random.PRNGKey(42)
self._batch_data = self._make_batched_data(config.num_envs)
# Step 5: EMA ctrl state (on GPU as JAX array)
self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
# Step 6: JIT-compile the hot-path functions
# Step 5: JIT-compile the hot-path functions
self._compile_jit_fns(config.substeps)
# Keep one CPU MjData for offscreen rendering
@@ -152,7 +148,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# ── Selective reset ─────────────────────────────────────────
@jax.jit
def reset_fn(data, smooth_ctrl, mask, rng):
def reset_fn(data, mask, rng):
rng, k1, k2 = jax.random.split(rng, 3)
ne = data.qpos.shape[0]
@@ -168,12 +164,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
new_ctrl = jnp.where(m, 0.0, data.ctrl)
new_smooth = jnp.where(m, 0.0, smooth_ctrl)
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
return new_data, new_smooth, rng
return new_data, rng
self._jit_reset = reset_fn
@@ -183,14 +178,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# PyTorch → JAX (zero-copy on GPU via DLPack)
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
# EMA smoothing (vectorised on GPU, no Python loop)
alpha = self.config.action_ema_alpha
self._smooth_ctrl = (
alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
)
# Set ctrl & run N substeps for all environments
self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
self._batch_data = self._jit_step(self._batch_data)
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
@@ -206,8 +195,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
mask[env_ids] = True
mask_jax = jnp.from_dlpack(mask)
self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
self._batch_data, self._rng = self._jit_reset(
self._batch_data, mask_jax, self._rng,
)
# Return only the reset environments' states

View File

@@ -16,7 +16,6 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
@dataclasses.dataclass
@@ -110,16 +109,39 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
}
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
ET.SubElement(act_elem, "position", attrib=attribs)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, "velocity", attrib=attribs)
else: # motor (default)
ET.SubElement(act_elem, "motor", attrib=attribs)
# Actuator dynamics: 1st-order low-pass filter on ctrl.
# MuJoCo applies this every physics substep, which is
# more physically accurate than an external EMA.
# dyntype is only available on <general>, not on
# shortcut elements like <motor>/<position>/<velocity>.
use_general = act.filter_tau > 0
if use_general:
attribs["dyntype"] = "filter"
attribs["dynprm"] = str(act.filter_tau)
if use_general:
attribs["gaintype"] = "fixed"
if act.type == "position":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
elif act.type == "velocity":
attribs["biastype"] = "affine"
attribs["gainprm"] = str(act.kp)
attribs["biasprm"] = f"0 0 -{act.kp}"
else: # motor
attribs["biastype"] = "none"
ET.SubElement(act_elem, "general", attrib=attribs)
else:
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, act.type, attrib=attribs)
# ── Apply joint overrides from robot.yaml ───────────────
# Merge actuator damping + explicit joint overrides
@@ -183,24 +205,16 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._nq = self._model.nq
self._nv = self._model.nv
# Per-env smoothed ctrl state for EMA action filtering.
# Models real motor inertia: ctrl can't reverse instantly.
nu = self._model.nu
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
self._limits = self._extract_actuator_limits(self._model)
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
alpha = self.config.action_ema_alpha
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data):
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
data.ctrl[:] = self._smooth_ctrl[i]
data.ctrl[:] = actions_np[i]
for _ in range(self.config.substeps):
# Software limit switch: zero out ctrl that would push
# a joint past its position limit (like a real controller).
@@ -245,7 +259,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
# Reset smoothed ctrl so motor starts from rest
self._smooth_ctrl[env_id][:] = 0.0
data.ctrl[:] = 0.0
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel