✨ initial commit
This commit is contained in:
155
src/runners/mujoco.py
Normal file
155
src/runners/mujoco.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import dataclasses
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
|
||||
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
    """Configuration values consumed by ``MuJoCoRunner``."""

    # How many independent simulation states (one ``MjData`` each) are created.
    num_envs: int = 16

    # Torch device string; tensors returned by the runner are moved onto it.
    device: str = "cpu"

    # Value written to the MuJoCo model's ``opt.timestep``.
    dt: float = 0.02

    # Number of ``mj_step`` calls executed per environment for each action.
    substeps: int = 2
|
||||
|
||||
class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
    """Runner that simulates ``config.num_envs`` environments with MuJoCo.

    A single ``mujoco.MjModel`` is shared by all environments and each
    environment owns its own ``mujoco.MjData``. Environments are stepped one
    after another in a Python loop; states are exchanged with the caller as
    torch tensors on ``self.device``.
    """

    def __init__(self, env: BaseEnv, config: MuJoCoRunnerConfig):
        super().__init__(env, config)

    @property
    def num_envs(self) -> int:
        """Number of simulated environments, taken from the config."""
        return self.config.num_envs

    @property
    def device(self) -> torch.device:
        """Torch device on which returned tensors are placed."""
        return torch.device(self.config.device)

    @staticmethod
    def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel:
        """Load a URDF (or MJCF) file and programmatically inject actuators.

        Two-step approach required because MuJoCo's URDF parser ignores
        <actuator> in the <mujoco> extension block:
          1. Load the URDF -> MuJoCo converts it to internal MJCF
          2. Export the MJCF XML, add <actuator> elements, reload

        This keeps the URDF clean and standard: actuator config lives in
        the env config (Isaac Lab pattern), not in the robot file.

        Args:
            model_path: Path of the URDF/MJCF file to load.
            actuators: Actuator descriptions; each contributes one ``<motor>``
                built from its ``joint``, ``gear`` and ``ctrl_range``.

        Returns:
            The model as loaded when ``actuators`` is empty, otherwise a model
            recompiled from the exported MJCF with the motors appended.
        """
        import os

        # Step 1: Load URDF/MJCF as-is (no actuators)
        model_raw = mujoco.MjModel.from_xml_path(model_path)

        if not actuators:
            return model_raw

        # Step 2: Export internal MJCF representation.
        # A private temporary directory replaces the deprecated, race-prone
        # tempfile.mktemp(); cleanup can no longer raise FileNotFoundError and
        # hide the real error when the export itself fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_mjcf = os.path.join(tmp_dir, "model.xml")
            mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
            # Parse from the file so ElementTree handles the encoding itself.
            root = ET.parse(tmp_mjcf).getroot()

        # Step 3: Inject actuators into the MJCF XML
        act_elem = ET.SubElement(root, "actuator")
        for act in actuators:
            ET.SubElement(
                act_elem,
                "motor",
                attrib={
                    "name": f"{act.joint}_motor",
                    "joint": act.joint,
                    "gear": str(act.gear),
                    "ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
                },
            )

        # Step 4: Reload from modified MJCF
        # NOTE(review): loading from a string drops the original file's
        # directory, so relative asset paths (e.g. meshes) may not resolve;
        # confirm for models that reference external assets.
        modified_xml = ET.tostring(root, encoding="unicode")
        return mujoco.MjModel.from_xml_string(modified_xml)

    def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
        """Load the model and allocate one ``MjData`` per environment.

        Args:
            config: Runner configuration supplying ``dt`` and ``num_envs``.

        Raises:
            ValueError: If the environment config has no ``model_path``.
        """
        model_path = self.env.config.model_path
        if model_path is None:
            raise ValueError("model_path must be specified in the environment config")

        actuators = self.env.config.actuators
        self._model = self._load_model_with_actuators(str(model_path), actuators)
        # NOTE(review): dt is used as the physics timestep, so one _sim_step
        # advances substeps * dt of simulated time; confirm this is intended
        # rather than dt / substeps.
        self._model.opt.timestep = config.dt
        self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]

        # Cached sizes of the position and velocity vectors.
        self._nq = self._model.nq
        self._nv = self._model.nv

    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply one action per environment and advance every simulation.

        Args:
            actions: Tensor indexed by environment along its first dimension;
                row ``i`` is written to the controls of environment ``i``.

        Returns:
            ``(qpos, qvel)`` float32 tensors on ``self.device`` with shapes
            ``(num_envs, nq)`` and ``(num_envs, nv)``.
        """
        # detach() so actions that still carry autograd history can be
        # converted; Tensor.numpy() raises on tensors that require grad.
        actions_np: np.ndarray = actions.detach().cpu().numpy()

        qpos_batch = np.zeros((self.num_envs, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((self.num_envs, self._nv), dtype=np.float32)

        for i, data in enumerate(self._data):
            data.ctrl[:] = actions_np[i]
            # The same control is held for all substeps.
            for _ in range(self.config.substeps):
                mujoco.mj_step(self._model, data)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the selected environments and return their new states.

        Args:
            env_ids: Tensor of environment indices to reset.

        Returns:
            ``(qpos, qvel)`` float32 tensors on ``self.device`` with one row
            per entry of ``env_ids``, in the same order.
        """
        ids = env_ids.cpu().numpy()
        n = len(ids)

        qpos_batch = np.zeros((n, self._nq), dtype=np.float32)
        qvel_batch = np.zeros((n, self._nv), dtype=np.float32)

        for i, env_id in enumerate(ids):
            data = self._data[int(env_id)]
            mujoco.mj_resetData(self._model, data)

            # Add small random perturbation so the pole doesn't start perfectly upright
            # NOTE(review): noise is added to every qpos entry, which assumes
            # no quaternion coordinates; confirm for free/ball joints.
            data.qpos[:] += np.random.uniform(-0.05, 0.05, size=self._nq)
            data.qvel[:] += np.random.uniform(-0.05, 0.05, size=self._nv)

            qpos_batch[i] = data.qpos
            qvel_batch[i] = data.qvel

        return (
            torch.from_numpy(qpos_batch).to(self.device),
            torch.from_numpy(qvel_batch).to(self.device),
        )

    def _sim_close(self) -> None:
        """Release the viewer, the offscreen renderer and all simulation data.

        Safe to call when nothing was created: every attribute is looked up
        with a default so a runner whose initialisation never ran (or failed)
        can still be closed.
        """
        viewer = getattr(self, "_viewer", None)
        if viewer is not None:
            viewer.close()
            self._viewer = None

        renderer = getattr(self, "_offscreen_renderer", None)
        if renderer is not None:
            renderer.close()
            self._offscreen_renderer = None

        data = getattr(self, "_data", None)
        if data is not None:
            data.clear()

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render one environment.

        Args:
            env_idx: Index of the environment to show.
            mode: ``"human"`` updates an interactive passive viewer window;
                ``"rgb_array"`` renders offscreen at 640x480.

        Returns:
            ``None`` for ``"human"``; for ``"rgb_array"`` a tensor built from
            the rendered pixel buffer.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == "human":
            # Resolve the data first so a bad index cannot tear down a live viewer.
            data = self._data[env_idx]
            viewer = getattr(self, "_viewer", None)

            # The passive viewer stays bound to the MjData it was launched
            # with, so it must be relaunched to display another environment.
            if viewer is not None and getattr(self, "_viewer_env_idx", env_idx) != env_idx:
                viewer.close()
                self._viewer = None
                viewer = None

            if viewer is None:
                viewer = mujoco.viewer.launch_passive(self._model, data)
                self._viewer = viewer
                self._viewer_env_idx = env_idx

            # Update visual geometry from current physics state
            mujoco.mj_forward(self._model, data)
            viewer.sync()
            return None

        if mode == "rgb_array":
            # Cache the offscreen renderer to avoid create/destroy overhead
            renderer = getattr(self, "_offscreen_renderer", None)
            if renderer is None:
                renderer = mujoco.Renderer(self._model, height=480, width=640)
                self._offscreen_renderer = renderer
            renderer.update_scene(self._data[env_idx])
            pixels = renderer.render().copy()  # copy since buffer is reused
            return torch.from_numpy(pixels)

        # Previously an unknown mode fell through and silently returned None.
        raise ValueError(f"Unsupported render mode: {mode!r} (expected 'human' or 'rgb_array')")
|
||||
Reference in New Issue
Block a user