initial commit

This commit is contained in:
2026-03-06 22:19:44 +01:00
commit c8f28ffbcc
17 changed files with 811 additions and 0 deletions

155
src/runners/mujoco.py Normal file
View File

@@ -0,0 +1,155 @@
import dataclasses
import tempfile
import xml.etree.ElementTree as ET
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
import mujoco.viewer
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
num_envs: int = 16
device: str = "cpu"
dt: float = 0.02
substeps: int = 2
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
super().__init__(env, config)
@property
def num_envs(self) -> int:
return self.config.num_envs
@property
def device(self) -> torch.device:
return torch.device(self.config.device)
@staticmethod
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
"""Load a URDF (or MJCF) file and programmatically inject actuators.
Two-step approach required because MuJoCo's URDF parser ignores
<actuator> in the <mujoco> extension block:
1. Load the URDF → MuJoCo converts it to internal MJCF
2. Export the MJCF XML, add <actuator> elements, reload
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
# Step 1: Load URDF/MJCF as-is (no actuators)
model_raw = mujoco.MjModel.from_xml_path(model_path)
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation
tmp_mjcf = tempfile.mktemp(suffix=".xml")
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
finally:
import os
os.unlink(tmp_mjcf)
# Step 3: Inject actuators into the MJCF XML
root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Step 4: Reload from modified MJCF
modified_xml = ET.tostring(root, encoding="unicode")
return mujoco.MjModel.from_xml_string(modified_xml)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path
if model_path is None:
raise ValueError("model_path must be specified in the environment config")
actuators = self.env.config.actuators
self._model = self._load_model_with_actuators(str(model_path), actuators)
self._model.opt.timestep = config.dt
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
self._nq = self._model.nq
self._nv = self._model.nv
def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
actions_np: np.ndarray = actions.cpu().numpy()
qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)
for i, data in enumerate(self._data):
data.ctrl[:] = actions_np[i]
for _ in range(self.config.substeps):
mujoco.mj_step(self._model, data)
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
return (
torch.from_numpy(qpos_batch).to(self.device),
torch.from_numpy(qvel_batch).to(self.device),
)
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
ids = env_ids.cpu().numpy()
n = len(ids)
qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
qvel_batch = np.zeros((n, self._nv), dtype=np.float32)
for i, env_id in enumerate(ids):
data = self._data[env_id]
mujoco.mj_resetData(self._model, data)
# Add small random perturbation so the pole doesn't start perfectly upright
data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)
qpos_batch[i] = data.qpos
qvel_batch[i] = data.qvel
return (
torch.from_numpy(qpos_batch).to(self.device),
torch.from_numpy(qvel_batch).to(self.device),
)
def _sim_close(self) -> None:
if hasattr(self, "_viewer") and self._viewer is not None:
self._viewer.close()
self._viewer = None
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
self._offscreen_renderer = None
self._data.clear()
def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
if mode == "human":
if not hasattr(self, "_viewer") or self._viewer is None:
self._viewer = mujoco.viewer.launch_passive(
self._model, self._data[env_idx]
)
# Update visual geometry from current physics state
mujoco.mj_forward(self._model, self._data[env_idx])
self._viewer.sync()
return None
elif mode == "rgb_array":
# Cache the offscreen renderer to avoid create/destroy overhead
if not hasattr(self, "_offscreen_renderer") or self._offscreen_renderer is None:
self._offscreen_renderer = mujoco.Renderer(self._model, height=480, width=640)
self._offscreen_renderer.update_scene(self._data[env_idx])
pixels = self._offscreen_renderer.render().copy() # copy since buffer is reused
return torch.from_numpy(pixels)