✨ remove custom ema and use mujoco motor control
This commit is contained in:
@@ -6,9 +6,10 @@ urdf: rotary_cartpole.urdf
|
||||
actuators:
|
||||
- joint: motor_joint
|
||||
type: motor # direct torque control
|
||||
gear: 0.5 # torque multiplier
|
||||
gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang)
|
||||
ctrl_range: [-1.0, 1.0]
|
||||
damping: 0.1 # motor friction / back-EMF
|
||||
damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget)
|
||||
filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia)
|
||||
|
||||
joints:
|
||||
pendulum_joint:
|
||||
|
||||
@@ -2,4 +2,3 @@ num_envs: 1024 # MJX shines with many parallel envs
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 20
|
||||
action_ema_alpha: 0.2
|
||||
|
||||
@@ -2,4 +2,3 @@ num_envs: 64
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 20
|
||||
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
||||
|
||||
@@ -3,7 +3,6 @@ import dataclasses
|
||||
from typing import TypeVar, Generic, Any
|
||||
from gymnasium import spaces
|
||||
import torch
|
||||
import pathlib
|
||||
|
||||
from src.core.robot import RobotConfig, load_robot_config
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ class ActuatorConfig:
|
||||
damping: float = 0.05
|
||||
kp: float = 0.0 # proportional gain (position / velocity actuators)
|
||||
kv: float = 0.0 # derivative gain (position actuators)
|
||||
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@@ -40,7 +40,6 @@ class MJXRunnerConfig(BaseRunnerConfig):
|
||||
device: str = "cuda"
|
||||
dt: float = 0.002
|
||||
substeps: int = 20
|
||||
action_ema_alpha: float = 0.2
|
||||
|
||||
|
||||
class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
@@ -86,10 +85,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
self._rng = jax.random.PRNGKey(42)
|
||||
self._batch_data = self._make_batched_data(config.num_envs)
|
||||
|
||||
# Step 5: EMA ctrl state (on GPU as JAX array)
|
||||
self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
|
||||
|
||||
# Step 6: JIT-compile the hot-path functions
|
||||
# Step 5: JIT-compile the hot-path functions
|
||||
self._compile_jit_fns(config.substeps)
|
||||
|
||||
# Keep one CPU MjData for offscreen rendering
|
||||
@@ -152,7 +148,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
|
||||
# ── Selective reset ─────────────────────────────────────────
|
||||
@jax.jit
|
||||
def reset_fn(data, smooth_ctrl, mask, rng):
|
||||
def reset_fn(data, mask, rng):
|
||||
rng, k1, k2 = jax.random.split(rng, 3)
|
||||
ne = data.qpos.shape[0]
|
||||
|
||||
@@ -168,12 +164,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
|
||||
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
|
||||
new_ctrl = jnp.where(m, 0.0, data.ctrl)
|
||||
new_smooth = jnp.where(m, 0.0, smooth_ctrl)
|
||||
|
||||
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
|
||||
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
|
||||
|
||||
return new_data, new_smooth, rng
|
||||
return new_data, rng
|
||||
|
||||
self._jit_reset = reset_fn
|
||||
|
||||
@@ -183,14 +178,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||
|
||||
# EMA smoothing (vectorised on GPU, no Python loop)
|
||||
alpha = self.config.action_ema_alpha
|
||||
self._smooth_ctrl = (
|
||||
alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
|
||||
)
|
||||
|
||||
# Set ctrl & run N substeps for all environments
|
||||
self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
|
||||
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||
self._batch_data = self._jit_step(self._batch_data)
|
||||
|
||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||
@@ -206,8 +195,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
mask[env_ids] = True
|
||||
mask_jax = jnp.from_dlpack(mask)
|
||||
|
||||
self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
|
||||
self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
|
||||
self._batch_data, self._rng = self._jit_reset(
|
||||
self._batch_data, mask_jax, self._rng,
|
||||
)
|
||||
|
||||
# Return only the reset environments' states
|
||||
|
||||
@@ -16,7 +16,6 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
device: str = "cpu"
|
||||
dt: float = 0.02
|
||||
substeps: int = 2
|
||||
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -110,16 +109,39 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
"gear": str(act.gear),
|
||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||
}
|
||||
if act.type == "position":
|
||||
attribs["kp"] = str(act.kp)
|
||||
if act.kv > 0:
|
||||
attribs["kv"] = str(act.kv)
|
||||
ET.SubElement(act_elem, "position", attrib=attribs)
|
||||
elif act.type == "velocity":
|
||||
attribs["kp"] = str(act.kp)
|
||||
ET.SubElement(act_elem, "velocity", attrib=attribs)
|
||||
else: # motor (default)
|
||||
ET.SubElement(act_elem, "motor", attrib=attribs)
|
||||
|
||||
# Actuator dynamics: 1st-order low-pass filter on ctrl.
|
||||
# MuJoCo applies this every physics substep, which is
|
||||
# more physically accurate than an external EMA.
|
||||
# dyntype is only available on <general>, not on
|
||||
# shortcut elements like <motor>/<position>/<velocity>.
|
||||
use_general = act.filter_tau > 0
|
||||
|
||||
if use_general:
|
||||
attribs["dyntype"] = "filter"
|
||||
attribs["dynprm"] = str(act.filter_tau)
|
||||
|
||||
if use_general:
|
||||
attribs["gaintype"] = "fixed"
|
||||
if act.type == "position":
|
||||
attribs["biastype"] = "affine"
|
||||
attribs["gainprm"] = str(act.kp)
|
||||
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
|
||||
elif act.type == "velocity":
|
||||
attribs["biastype"] = "affine"
|
||||
attribs["gainprm"] = str(act.kp)
|
||||
attribs["biasprm"] = f"0 0 -{act.kp}"
|
||||
else: # motor
|
||||
attribs["biastype"] = "none"
|
||||
ET.SubElement(act_elem, "general", attrib=attribs)
|
||||
else:
|
||||
if act.type == "position":
|
||||
attribs["kp"] = str(act.kp)
|
||||
if act.kv > 0:
|
||||
attribs["kv"] = str(act.kv)
|
||||
elif act.type == "velocity":
|
||||
attribs["kp"] = str(act.kp)
|
||||
ET.SubElement(act_elem, act.type, attrib=attribs)
|
||||
|
||||
# ── Apply joint overrides from robot.yaml ───────────────
|
||||
# Merge actuator damping + explicit joint overrides
|
||||
@@ -183,24 +205,16 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
self._nq = self._model.nq
|
||||
self._nv = self._model.nv
|
||||
|
||||
# Per-env smoothed ctrl state for EMA action filtering.
|
||||
# Models real motor inertia: ctrl can't reverse instantly.
|
||||
nu = self._model.nu
|
||||
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
||||
|
||||
self._limits = self._extract_actuator_limits(self._model)
|
||||
|
||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
actions_np: np.ndarray = actions.cpu().numpy()
|
||||
alpha = self.config.action_ema_alpha
|
||||
|
||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||
|
||||
for i, data in enumerate(self._data):
|
||||
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
|
||||
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
||||
data.ctrl[:] = self._smooth_ctrl[i]
|
||||
data.ctrl[:] = actions_np[i]
|
||||
for _ in range(self.config.substeps):
|
||||
# Software limit switch: zero out ctrl that would push
|
||||
# a joint past its position limit (like a real controller).
|
||||
@@ -245,7 +259,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||
|
||||
# Reset smoothed ctrl so motor starts from rest
|
||||
self._smooth_ctrl[env_id][:] = 0.0
|
||||
data.ctrl[:] = 0.0
|
||||
|
||||
qpos_batch[i] = data.qpos
|
||||
qvel_batch[i] = data.qvel
|
||||
|
||||
Reference in New Issue
Block a user