✨ remove custom ema and use mujoco motor control
This commit is contained in:
@@ -6,9 +6,10 @@ urdf: rotary_cartpole.urdf
|
|||||||
actuators:
|
actuators:
|
||||||
- joint: motor_joint
|
- joint: motor_joint
|
||||||
type: motor # direct torque control
|
type: motor # direct torque control
|
||||||
gear: 0.5 # torque multiplier
|
gear: 0.1 # torque multiplier (was 0.5 — too weak, caused bang-bang)
|
||||||
ctrl_range: [-1.0, 1.0]
|
ctrl_range: [-1.0, 1.0]
|
||||||
damping: 0.1 # motor friction / back-EMF
|
damping: 0.02 # motor friction / back-EMF (was 0.1 — ate torque budget)
|
||||||
|
filter_tau: 0.18 # 1st-order filter ~180ms (models motor inertia)
|
||||||
|
|
||||||
joints:
|
joints:
|
||||||
pendulum_joint:
|
pendulum_joint:
|
||||||
|
|||||||
@@ -2,4 +2,3 @@ num_envs: 1024 # MJX shines with many parallel envs
|
|||||||
device: auto # auto = cuda if available, else cpu
|
device: auto # auto = cuda if available, else cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 20
|
substeps: 20
|
||||||
action_ema_alpha: 0.2
|
|
||||||
|
|||||||
@@ -2,4 +2,3 @@ num_envs: 64
|
|||||||
device: auto # auto = cuda if available, else cpu
|
device: auto # auto = cuda if available, else cpu
|
||||||
dt: 0.002
|
dt: 0.002
|
||||||
substeps: 20
|
substeps: 20
|
||||||
action_ema_alpha: 0.2 # motor smoothing (τ ≈ 179ms, ~5 steps to reverse)
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import dataclasses
|
|||||||
from typing import TypeVar, Generic, Any
|
from typing import TypeVar, Generic, Any
|
||||||
from gymnasium import spaces
|
from gymnasium import spaces
|
||||||
import torch
|
import torch
|
||||||
import pathlib
|
|
||||||
|
|
||||||
from src.core.robot import RobotConfig, load_robot_config
|
from src.core.robot import RobotConfig, load_robot_config
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class ActuatorConfig:
|
|||||||
damping: float = 0.05
|
damping: float = 0.05
|
||||||
kp: float = 0.0 # proportional gain (position / velocity actuators)
|
kp: float = 0.0 # proportional gain (position / velocity actuators)
|
||||||
kv: float = 0.0 # derivative gain (position actuators)
|
kv: float = 0.0 # derivative gain (position actuators)
|
||||||
|
filter_tau: float = 0.0 # 1st-order filter time constant (s); 0 = no filter
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ class MJXRunnerConfig(BaseRunnerConfig):
|
|||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
dt: float = 0.002
|
dt: float = 0.002
|
||||||
substeps: int = 20
|
substeps: int = 20
|
||||||
action_ema_alpha: float = 0.2
|
|
||||||
|
|
||||||
|
|
||||||
class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||||
@@ -86,10 +85,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
self._rng = jax.random.PRNGKey(42)
|
self._rng = jax.random.PRNGKey(42)
|
||||||
self._batch_data = self._make_batched_data(config.num_envs)
|
self._batch_data = self._make_batched_data(config.num_envs)
|
||||||
|
|
||||||
# Step 5: EMA ctrl state (on GPU as JAX array)
|
# Step 5: JIT-compile the hot-path functions
|
||||||
self._smooth_ctrl = jnp.zeros((config.num_envs, self._nu))
|
|
||||||
|
|
||||||
# Step 6: JIT-compile the hot-path functions
|
|
||||||
self._compile_jit_fns(config.substeps)
|
self._compile_jit_fns(config.substeps)
|
||||||
|
|
||||||
# Keep one CPU MjData for offscreen rendering
|
# Keep one CPU MjData for offscreen rendering
|
||||||
@@ -152,7 +148,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
|
|
||||||
# ── Selective reset ─────────────────────────────────────────
|
# ── Selective reset ─────────────────────────────────────────
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def reset_fn(data, smooth_ctrl, mask, rng):
|
def reset_fn(data, mask, rng):
|
||||||
rng, k1, k2 = jax.random.split(rng, 3)
|
rng, k1, k2 = jax.random.split(rng, 3)
|
||||||
ne = data.qpos.shape[0]
|
ne = data.qpos.shape[0]
|
||||||
|
|
||||||
@@ -168,12 +164,11 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
|
new_qpos = jnp.where(m, default.qpos + pq, data.qpos)
|
||||||
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
|
new_qvel = jnp.where(m, default.qvel + pv, data.qvel)
|
||||||
new_ctrl = jnp.where(m, 0.0, data.ctrl)
|
new_ctrl = jnp.where(m, 0.0, data.ctrl)
|
||||||
new_smooth = jnp.where(m, 0.0, smooth_ctrl)
|
|
||||||
|
|
||||||
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
|
new_data = data.replace(qpos=new_qpos, qvel=new_qvel, ctrl=new_ctrl)
|
||||||
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
|
new_data = jax.vmap(mjx.forward, in_axes=(None, 0))(model, new_data)
|
||||||
|
|
||||||
return new_data, new_smooth, rng
|
return new_data, rng
|
||||||
|
|
||||||
self._jit_reset = reset_fn
|
self._jit_reset = reset_fn
|
||||||
|
|
||||||
@@ -183,14 +178,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
# PyTorch → JAX (zero-copy on GPU via DLPack)
|
||||||
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
actions_jax = jnp.from_dlpack(actions.detach().contiguous())
|
||||||
|
|
||||||
# EMA smoothing (vectorised on GPU, no Python loop)
|
|
||||||
alpha = self.config.action_ema_alpha
|
|
||||||
self._smooth_ctrl = (
|
|
||||||
alpha * actions_jax + (1.0 - alpha) * self._smooth_ctrl
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set ctrl & run N substeps for all environments
|
# Set ctrl & run N substeps for all environments
|
||||||
self._batch_data = self._batch_data.replace(ctrl=self._smooth_ctrl)
|
self._batch_data = self._batch_data.replace(ctrl=actions_jax)
|
||||||
self._batch_data = self._jit_step(self._batch_data)
|
self._batch_data = self._jit_step(self._batch_data)
|
||||||
|
|
||||||
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
# JAX → PyTorch (zero-copy on GPU via DLPack, cast to float32)
|
||||||
@@ -206,8 +195,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
|||||||
mask[env_ids] = True
|
mask[env_ids] = True
|
||||||
mask_jax = jnp.from_dlpack(mask)
|
mask_jax = jnp.from_dlpack(mask)
|
||||||
|
|
||||||
self._batch_data, self._smooth_ctrl, self._rng = self._jit_reset(
|
self._batch_data, self._rng = self._jit_reset(
|
||||||
self._batch_data, self._smooth_ctrl, mask_jax, self._rng,
|
self._batch_data, mask_jax, self._rng,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only the reset environments' states
|
# Return only the reset environments' states
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
|
|||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
dt: float = 0.02
|
dt: float = 0.02
|
||||||
substeps: int = 2
|
substeps: int = 2
|
||||||
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -110,16 +109,39 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
"gear": str(act.gear),
|
"gear": str(act.gear),
|
||||||
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
|
||||||
}
|
}
|
||||||
if act.type == "position":
|
|
||||||
attribs["kp"] = str(act.kp)
|
# Actuator dynamics: 1st-order low-pass filter on ctrl.
|
||||||
if act.kv > 0:
|
# MuJoCo applies this every physics substep, which is
|
||||||
attribs["kv"] = str(act.kv)
|
# more physically accurate than an external EMA.
|
||||||
ET.SubElement(act_elem, "position", attrib=attribs)
|
# dyntype is only available on <general>, not on
|
||||||
elif act.type == "velocity":
|
# shortcut elements like <motor>/<position>/<velocity>.
|
||||||
attribs["kp"] = str(act.kp)
|
use_general = act.filter_tau > 0
|
||||||
ET.SubElement(act_elem, "velocity", attrib=attribs)
|
|
||||||
else: # motor (default)
|
if use_general:
|
||||||
ET.SubElement(act_elem, "motor", attrib=attribs)
|
attribs["dyntype"] = "filter"
|
||||||
|
attribs["dynprm"] = str(act.filter_tau)
|
||||||
|
|
||||||
|
if use_general:
|
||||||
|
attribs["gaintype"] = "fixed"
|
||||||
|
if act.type == "position":
|
||||||
|
attribs["biastype"] = "affine"
|
||||||
|
attribs["gainprm"] = str(act.kp)
|
||||||
|
attribs["biasprm"] = f"0 -{act.kp} -{act.kv}"
|
||||||
|
elif act.type == "velocity":
|
||||||
|
attribs["biastype"] = "affine"
|
||||||
|
attribs["gainprm"] = str(act.kp)
|
||||||
|
attribs["biasprm"] = f"0 0 -{act.kp}"
|
||||||
|
else: # motor
|
||||||
|
attribs["biastype"] = "none"
|
||||||
|
ET.SubElement(act_elem, "general", attrib=attribs)
|
||||||
|
else:
|
||||||
|
if act.type == "position":
|
||||||
|
attribs["kp"] = str(act.kp)
|
||||||
|
if act.kv > 0:
|
||||||
|
attribs["kv"] = str(act.kv)
|
||||||
|
elif act.type == "velocity":
|
||||||
|
attribs["kp"] = str(act.kp)
|
||||||
|
ET.SubElement(act_elem, act.type, attrib=attribs)
|
||||||
|
|
||||||
# ── Apply joint overrides from robot.yaml ───────────────
|
# ── Apply joint overrides from robot.yaml ───────────────
|
||||||
# Merge actuator damping + explicit joint overrides
|
# Merge actuator damping + explicit joint overrides
|
||||||
@@ -183,24 +205,16 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
self._nq = self._model.nq
|
self._nq = self._model.nq
|
||||||
self._nv = self._model.nv
|
self._nv = self._model.nv
|
||||||
|
|
||||||
# Per-env smoothed ctrl state for EMA action filtering.
|
|
||||||
# Models real motor inertia: ctrl can't reverse instantly.
|
|
||||||
nu = self._model.nu
|
|
||||||
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
|
|
||||||
|
|
||||||
self._limits = self._extract_actuator_limits(self._model)
|
self._limits = self._extract_actuator_limits(self._model)
|
||||||
|
|
||||||
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
actions_np: np.ndarray = actions.cpu().numpy()
|
actions_np: np.ndarray = actions.cpu().numpy()
|
||||||
alpha = self.config.action_ema_alpha
|
|
||||||
|
|
||||||
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
|
||||||
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
|
||||||
|
|
||||||
for i, data in enumerate(self._data):
|
for i, data in enumerate(self._data):
|
||||||
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
|
data.ctrl[:] = actions_np[i]
|
||||||
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
|
|
||||||
data.ctrl[:] = self._smooth_ctrl[i]
|
|
||||||
for _ in range(self.config.substeps):
|
for _ in range(self.config.substeps):
|
||||||
# Software limit switch: zero out ctrl that would push
|
# Software limit switch: zero out ctrl that would push
|
||||||
# a joint past its position limit (like a real controller).
|
# a joint past its position limit (like a real controller).
|
||||||
@@ -245,7 +259,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
|
||||||
|
|
||||||
# Reset smoothed ctrl so motor starts from rest
|
# Reset smoothed ctrl so motor starts from rest
|
||||||
self._smooth_ctrl[env_id][:] = 0.0
|
data.ctrl[:] = 0.0
|
||||||
|
|
||||||
qpos_batch[i] = data.qpos
|
qpos_batch[i] = data.qpos
|
||||||
qvel_batch[i] = data.qvel
|
qvel_batch[i] = data.qvel
|
||||||
|
|||||||
Reference in New Issue
Block a user