import dataclasses
import xml.etree.ElementTree as ET
from pathlib import Path

import mujoco
import numpy as np
import torch

from src.core.env import BaseEnv
from src.core.robot import RobotConfig
from src.core.runner import BaseRunner, BaseRunnerConfig


@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """Settings for the CPU MuJoCo runner.

    ``dt`` is written to ``model.opt.timestep`` and ``substeps`` is the number
    of ``mj_step`` calls made per ``_sim_step``, so one runner step advances
    the simulation by ``dt * substeps`` seconds.
    """

    num_envs: int = 16
    device: str = "cpu"
    dt: float = 0.02
    substeps: int = 2
    action_ema_alpha: float = 0.2  # EMA smoothing on ctrl (0=frozen, 1=instant)


class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """Runner that steps ``num_envs`` independent ``MjData`` instances of one model."""

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of simulated environments, taken from the config."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device on which returned state tensors are placed."""
        return torch.device(self.config.device)

    @staticmethod
    def _load_raw_model(abs_path: Path) -> mujoco.MjModel:
        """Load the model file as-is (step 1 of ``_load_model``).

        MuJoCo's URDF parser strips directory prefixes from mesh filenames,
        so for URDF input a ``<mujoco><compiler meshdir="..."/></mujoco>``
        block is injected into a temporary copy. The original URDF stays
        clean and simulator-agnostic. Any other file is loaded directly.
        """
        if abs_path.suffix.lower() != ".urdf":
            return mujoco.MjModel.from_xml_path(str(abs_path))

        tree = ET.parse(abs_path)
        root = tree.getroot()

        # Detect the mesh subdirectory from the first mesh filename that
        # carries a directory component.
        meshdir = None
        for mesh_el in root.iter("mesh"):
            parent = str(Path(mesh_el.get("filename", "")).parent)
            if parent and parent != ".":
                meshdir = parent
                break

        if meshdir:
            mj_ext = ET.SubElement(root, "mujoco")
            ET.SubElement(
                mj_ext,
                "compiler",
                attrib={"meshdir": meshdir, "balanceinertia": "true"},
            )

        # The copy lives next to the original so relative mesh paths resolve.
        tmp_urdf = abs_path.parent / "_tmp_mujoco_load.urdf"
        try:
            # Writing inside the try guarantees a partially written copy is
            # removed as well.
            tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
            return mujoco.MjModel.from_xml_path(str(tmp_urdf))
        finally:
            # missing_ok: a failed write must not mask the original error.
            tmp_urdf.unlink(missing_ok=True)

    @staticmethod
    def _inject_actuators(root: ET.Element, actuators) -> None:
        """Append an ``<actuator>`` section with one entry per configured actuator."""
        if not actuators:
            return
        act_elem = ET.SubElement(root, "actuator")
        for act in actuators:
            attribs = {
                "name": f"{act.joint}_{act.type}",
                "joint": act.joint,
                "gear": str(act.gear),
                "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
            }
            if act.type == "position":
                attribs["kp"] = str(act.kp)
                if act.kv > 0:
                    attribs["kv"] = str(act.kv)
                ET.SubElement(act_elem, "position", attrib=attribs)
            elif act.type == "velocity":
                # MJCF <velocity> exposes its feedback gain as "kv"; "kp" is
                # not part of its schema and makes the reload fail. Prefer an
                # explicit kv; fall back to kp so configs that put the
                # velocity gain there keep working.
                # NOTE(review): confirm this gain mapping matches robot.yaml.
                gain = act.kv if act.kv > 0 else act.kp
                attribs["kv"] = str(gain)
                ET.SubElement(act_elem, "velocity", attrib=attribs)
            else:  # motor (default)
                ET.SubElement(act_elem, "motor", attrib=attribs)

    @staticmethod
    def _apply_joint_damping(root: ET.Element, actuators, joints) -> None:
        """Set per-joint damping: actuator values first, explicit joint overrides win."""
        # Unset (None) damping is skipped so the literal string "None" never
        # reaches the XML.
        joint_damping = {a.joint: a.damping for a in actuators if a.damping is not None}
        for name, jcfg in joints.items():
            if jcfg.damping is not None:
                joint_damping[name] = jcfg.damping

        # Only joints that are direct children of a <body> are touched.
        for body in root.iter("body"):
            for jnt in body.findall("joint"):
                name = jnt.get("name")
                if name in joint_damping:
                    jnt.set("damping", str(joint_damping[name]))

    @staticmethod
    def _harden_joint_limits(root: ET.Element) -> None:
        """Stiffen the limit constraint of every range-limited body joint.

        MuJoCo's default soft limits are too weak and allow overshoot.
        Negative solref = hard constraint (direct stiffness/damping instead
        of impedance match).
        """
        for body in root.iter("body"):
            for jnt in body.findall("joint"):
                if jnt.get("limited") == "true" or jnt.get("range"):
                    jnt.set("solreflimit", "-1000 -100")
                    jnt.set("solimplimit", "0.95 0.99 0.001")

    @staticmethod
    def _load_model(robot: RobotConfig) -> mujoco.MjModel:
        """Load a URDF (or MJCF) and apply robot.yaml settings.

        Two-step approach required because MuJoCo's URDF loader does not pick
        up actuator definitions from the URDF's ``<mujoco>`` extension block:
          1. Load the URDF → MuJoCo converts it to internal MJCF
          2. Export the MJCF XML, inject actuators + joint overrides, reload

        This keeps the URDF clean (re-exportable from CAD) — all hardware
        tuning lives in robot.yaml.

        Args:
            robot: Robot description; ``urdf_path``, ``actuators`` and
                ``joints`` are read.

        Returns:
            The compiled model. When neither actuators nor joint overrides
            are configured, the step-1 model is returned unmodified.
        """
        abs_path = robot.urdf_path.resolve()
        model_dir = abs_path.parent

        # Step 1: plain load (with the mesh-directory workaround for URDF).
        model_raw = MuJoCoRunner._load_raw_model(abs_path)

        # Treat missing collections as empty so the code below can iterate.
        actuators = robot.actuators or []
        joints = robot.joints or {}
        if not actuators and not joints:
            return model_raw

        # Step 2: Export internal MJCF, inject actuators + joint overrides, reload
        tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
        try:
            mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
            root = ET.fromstring(tmp_mjcf.read_text())

            MuJoCoRunner._inject_actuators(root, actuators)
            MuJoCoRunner._apply_joint_damping(root, actuators, joints)

            # Disable self-collision on all geoms.
            # URDF mesh convex hulls often overlap at joints (especially
            # grandparent↔grandchild bodies that MuJoCo does NOT auto-exclude),
            # causing phantom contact forces.
            # NOTE(review): zeroing contype and conaffinity on every geom
            # removes all contacts involving these geoms, not only
            # self-collision — confirm that is intended.
            for geom in root.iter("geom"):
                geom.set("contype", "0")
                geom.set("conaffinity", "0")

            MuJoCoRunner._harden_joint_limits(root)

            # Write the modified MJCF and reload from file path
            # (from_xml_path resolves mesh paths relative to the file location)
            tmp_mjcf.write_text(ET.tostring(root, encoding="unicode"))
            return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
        finally:
            tmp_mjcf.unlink(missing_ok=True)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Build the model, one ``MjData`` per environment, and the ctrl filter state."""
        self._model = self._load_model(self.env.robot)
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [
            mujoco.MjData(self._model) for _ in range(config.num_envs)
        ]
        self._nq = self._model.nq
        self._nv = self._model.nv

        # Per-env smoothed ctrl state for EMA action filtering.
        # Models real motor inertia: ctrl can't reverse instantly.
        nu = self._model.nu
        self._smooth_ctrl = [
            np.zeros(nu, dtype=np.float64) for _ in range(config.num_envs)
        ]

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply EMA-filtered actions and advance every environment.

        Args:
            actions: Raw controls, indexed by environment along the first
                axis (row ``i`` is written to environment ``i``'s ``ctrl``).

        Returns:
            ``(qpos, qvel)`` float32 tensors on ``self.device`` with one row
            per environment.
        """
        # detach() so tensors that still carry autograd history can be converted.
        actions_np: np.ndarray = actions.detach().cpu().numpy()
        alpha = self.config.action_ema_alpha
        substeps = self.config.substeps
        model = self._model

        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)

        for i, data in enumerate(self._data):
            # EMA filter: smooth_ctrl ← α·raw + (1-α)·smooth_ctrl
            self._smooth_ctrl[i] = alpha * actions_np[i] + (1 - alpha) * self._smooth_ctrl[i]
            data.ctrl[:] = self._smooth_ctrl[i]
            for _ in range(substeps):
                mujoco.mj_step(model, data)
            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected environments to a slightly perturbed initial state.

        Args:
            env_ids: Indices of the environments to reset.

        Returns:
            ``(qpos, qvel)`` float32 tensors on ``self.device``; row ``k``
            belongs to ``env_ids[k]``.
        """
        # atleast_1d lets a scalar (0-d) index tensor reset a single env.
        ids = np.atleast_1d(env_ids.cpu().numpy())
        n = len(ids)
        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)

        default_qpos = self.env.get_default_qpos(self._nq)

        for i, env_id in enumerate(ids):
            data = self._data[env_id]
            mujoco.mj_resetData(self._model, data)

            # Set initial pose (env-specific, e.g. pendulum hanging down)
            if default_qpos is not None:
                data.qpos[:] = default_qpos

            # Add small random perturbation for exploration
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)

            # Reset smoothed ctrl so motor starts from rest
            self._smooth_ctrl[env_id][:] = 0.0

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def render(self, env_idx: int = 0) -> np.ndarray:
        """Offscreen render of a single environment."""
        # The 640x480 renderer is created on first use and then reused.
        if not hasattr(self, "_offscreen_renderer"):
            self._offscreen_renderer = mujoco.Renderer(
                self._model,
                width=640,
                height=480,
            )
        data = self._data[env_idx]
        mujoco.mj_forward(self._model, data)
        self._offscreen_renderer.update_scene(data)
        return self._offscreen_renderer.render()