add rotary cartpole env

This commit is contained in:
2026-03-08 22:58:32 +01:00
parent c8f28ffbcc
commit c753c369b4
15 changed files with 464 additions and 171 deletions

View File

@@ -1,12 +1,11 @@
import dataclasses
import tempfile
import os
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
import mujoco.viewer
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
@@ -14,6 +13,7 @@ class MuJoCoRunnerConfig(BaseRunnerConfig):
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
action_ema_alpha: float = 0.2 # EMA smoothing on ctrl (0=frozen, 1=instant)
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
@@ -39,36 +39,89 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
# Step 1: Load URDF/MJCF as-is (no actuators)
model_raw = mujoco.MjModel.from_xml_path(model_path)
abs_path = os.path.abspath(model_path)
model_dir = os.path.dirname(abs_path)
is_urdf = abs_path.lower().endswith(".urdf")
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
# so we inject a <mujoco><compiler meshdir="..."/> block into a
# temporary copy. The original URDF stays clean and simulator-agnostic.
if is_urdf:
tree = ET.parse(abs_path)
root = tree.getroot()
# Detect the mesh subdirectory from the first mesh filename
meshdir = None
for mesh_el in root.iter("mesh"):
fn = mesh_el.get("filename", "")
dirname = os.path.dirname(fn)
if dirname:
meshdir = dirname
break
if meshdir:
mj_ext = ET.SubElement(root, "mujoco")
ET.SubElement(mj_ext, "compiler", attrib={
"meshdir": meshdir,
"balanceinertia": "true",
})
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
try:
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
finally:
os.unlink(tmp_urdf)
else:
model_raw = mujoco.MjModel.from_xml_path(abs_path)
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation
tmp_mjcf = tempfile.mktemp(suffix=".xml")
# Step 2: Export internal MJCF representation (save next to original
# model so relative mesh/asset paths resolve correctly on reload)
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
# at steady state, vel_max = torque / damping.
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Add damping to actuated joints to limit max speed and
# mimic real motor friction / back-EMF.
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
actuated_joints = {a.joint for a in actuators}
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("name") in actuated_joints:
jnt.set("damping", "0.05")
# Disable self-collision on all geoms.
# URDF mesh convex hulls often overlap at joints (especially
# grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
# causing phantom contact forces.
for geom in root.iter("geom"):
geom.set("contype", "0")
geom.set("conaffinity", "0")
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode")
with open(tmp_mjcf, "w") as f:
f.write(modified_xml)
return mujoco.MjModel.from_xml_path(tmp_mjcf)
finally:
import os
os.unlink(tmp_mjcf)
# Step 3: Inject actuators into the MJCF XML
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Step 4: Reload from modified MJCF
modified_xml = ET.tostring(root, encoding="unicode")
return mujoco.MjModel.from_xml_string(modified_xml)
if os.path.exists(tmp_mjcf):
os.unlink(tmp_mjcf)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
@@ -83,14 +136,22 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
self._nq = self._model.nq
self._nv = self._model.nv
# Per-env smoothed ctrl state for EMA action filtering.
# Models real motor inertia: ctrl can't reverse instantly.
nu = self._model.nu
self._smooth_ctrl = [np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)]
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
alpha = self.config.action_ema_alpha
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data):
data.ctrl[:] = actions_np[i]
# EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
data.ctrl[:] = self._smooth_ctrl[i]
for _ in range(self.config.substeps):
mujoco.mj_step(self._model, data)
@@ -109,14 +170,23 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
default_qpos = self.env.get_default_qpos(self._nq)
for i, env_id in enumerate(ids):
data = self._data[env_id]
mujoco.mj_resetData(self._model, data)
# Add small random perturbation so the pole doesn't start perfectly upright
# Set initial pose (env-specific, e.g. pendulum hanging down)
if default_qpos is not None:
data.qpos[:] = default_qpos
# Add small random perturbation for exploration
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
# Reset smoothed ctrl so motor starts from rest
self._smooth_ctrl[env_id][:] = 0.0
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
@@ -126,30 +196,14 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
)
def _sim_close(self) -> None:
if hasattr(self, "_viewer") and self._viewer is not None:
self._viewer.close()
self._viewer = None
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
self._offscreen_renderer = None
self._data.clear()
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
if mode == "human":
if not hasattr(self, "_viewer") or self._viewer is None:
self._viewer = mujoco.viewer.launch_passive(
self._model, self._data[env_idx]
)
# Update visual geometry from current physics state
mujoco.mj_forward(self._model, self._data[env_idx])
self._viewer.sync()
return None
elif mode == "rgb_array":
# Cache the offscreen renderer to avoid create/destroy overhead
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
self._offscreen_renderer.update_scene(self._data[env_idx])
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
return torch.from_numpy(pixels)
def render(self, env_idx: int = 0) -> np.ndarray | None:
"""Offscreen render → RGB numpy array (H, W, 3)."""
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
self._offscreen_renderer.update_scene(self._data[env_idx])
return self._offscreen_renderer.render().copy()