better robot joint loading

This commit is contained in:
2026-03-09 22:17:28 +01:00
parent 9be07d9186
commit 70cd2cdd7d
13 changed files with 215 additions and 128 deletions

View File

@@ -0,0 +1,10 @@
# Classic cartpole — robot hardware config.
urdf: cartpole.urdf
actuators:
- joint: cart_joint
type: motor
gear: 10.0
ctrl_range: [-1.0, 1.0]
damping: 0.05

View File

@@ -0,0 +1,15 @@
# Rotary cartpole (Furuta pendulum) — robot hardware config.
# Lives next to the URDF so all robot-specific settings are in one place.
urdf: rotary_cartpole.urdf
actuators:
- joint: motor_joint
type: motor # direct torque control
gear: 0.5 # torque multiplier
ctrl_range: [-1.0, 1.0]
damping: 0.1 # motor friction / back-EMF
joints:
pendulum_joint:
damping: 0.0001 # bearing friction

View File

@@ -1,12 +1,7 @@
max_steps: 500 max_steps: 500
robot_path: assets/cartpole
angle_threshold: 0.418 angle_threshold: 0.418
cart_limit: 2.4 cart_limit: 2.4
reward_alive: 1.0 reward_alive: 1.0
reward_pole_upright_scale: 1.0 reward_pole_upright_scale: 1.0
reward_action_penalty_scale: 0.01 reward_action_penalty_scale: 0.01
model_path: assets/cartpole/cartpole.urdf
actuators:
- joint: cart_joint
gear: 10.0
ctrl_range: [-1.0, 1.0]
damping: 0.05

View File

@@ -1,8 +1,3 @@
max_steps: 1000 max_steps: 1000
model_path: assets/rotary_cartpole/rotary_cartpole.urdf robot_path: assets/rotary_cartpole
reward_upright_scale: 1.0 reward_upright_scale: 1.0
actuators:
- joint: motor_joint
gear: 0.5
ctrl_range: [-1.0, 1.0]
damping: 0.1

View File

@@ -10,4 +10,5 @@ clearml
imageio imageio
imageio-ffmpeg imageio-ffmpeg
structlog structlog
pyyaml
pytest pytest

View File

@@ -5,30 +5,20 @@ from gymnasium import spaces
import torch import torch
import pathlib import pathlib
from src.core.robot import RobotConfig, load_robot_config
T = TypeVar("T") T = TypeVar("T")
@dataclasses.dataclass
class ActuatorConfig:
"""Actuator definition — maps a joint to a motor with gear ratio and control limits.
Kept in the env config (not runner config) because actuators define what the robot
can do, which determines action space — a task-level concept.
This mirrors Isaac Lab's pattern of separating actuator config from the robot file."""
joint: str = ""
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
@dataclasses.dataclass @dataclasses.dataclass
class BaseEnvConfig: class BaseEnvConfig:
max_steps: int = 1000 max_steps: int = 1000
model_path: pathlib.Path | None = None robot_path: str = "" # directory containing robot.yaml + URDF
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
class BaseEnv(abc.ABC, Generic[T]): class BaseEnv(abc.ABC, Generic[T]):
def __init__(self, config: BaseEnvConfig): def __init__(self, config: BaseEnvConfig):
self.config = config self.config = config
self.robot: RobotConfig = load_robot_config(config.robot_path)
@property @property
@abc.abstractmethod @abc.abstractmethod

23
src/core/registry.py Normal file
View File

@@ -0,0 +1,23 @@
"""Shared env registry and builder used by train.py and viz.py."""
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
"""Instantiate the right env + config from the Hydra config-group name."""
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
return env_cls(config_cls(**env_dict))

101
src/core/robot.py Normal file
View File

@@ -0,0 +1,101 @@
"""Robot hardware configuration — loaded from robot.yaml next to the URDF.
Separates robot hardware (actuators, joint tuning) from task config
(rewards, episode length) and from the URDF (clean CAD export).
Usage:
robot = load_robot_config(Path("assets/rotary_cartpole"))
# robot.urdf_path → resolved absolute path to the URDF
# robot.actuators → list of ActuatorConfig
# robot.joints → dict of per-joint overrides
"""
import dataclasses
from pathlib import Path
import structlog
import yaml
log = structlog.get_logger()
@dataclasses.dataclass
class ActuatorConfig:
"""Motor/actuator attached to a joint.
type:
motor — direct torque control (ctrl = normalised torque)
position — PD position servo (ctrl = target angle, needs kp)
velocity — P velocity servo (ctrl = target velocity, needs kp)
"""
joint: str = ""
type: str = "motor"
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
damping: float = 0.05
kp: float = 0.0 # proportional gain (position / velocity actuators)
kv: float = 0.0 # derivative gain (position actuators)
@dataclasses.dataclass
class JointConfig:
"""Per-joint overrides applied on top of the URDF values."""
damping: float | None = None
@dataclasses.dataclass
class RobotConfig:
"""Complete robot hardware description."""
urdf_path: Path = dataclasses.field(default_factory=lambda: Path())
actuators: list[ActuatorConfig] = dataclasses.field(default_factory=list)
joints: dict[str, JointConfig] = dataclasses.field(default_factory=dict)
def load_robot_config(robot_dir: str | Path) -> RobotConfig:
"""Load robot.yaml from a directory and resolve the URDF path.
Expected layout:
robot_dir/
robot.yaml ← hardware config
some_robot.urdf ← CAD export
meshes/ ← optional mesh files
"""
robot_dir = Path(robot_dir).resolve()
yaml_path = robot_dir / "robot.yaml"
if not yaml_path.exists():
raise FileNotFoundError(f"Robot config not found: {yaml_path}")
raw = yaml.safe_load(yaml_path.read_text())
# Resolve URDF path relative to robot.yaml directory
urdf_filename = raw.get("urdf", "")
if not urdf_filename:
raise ValueError(f"robot.yaml must specify 'urdf' filename: {yaml_path}")
urdf_path = robot_dir / urdf_filename
if not urdf_path.exists():
raise FileNotFoundError(f"URDF not found: {urdf_path}")
# Parse actuators
actuators = []
for a in raw.get("actuators", []):
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
actuators.append(ActuatorConfig(**a))
# Parse joint overrides
joints = {}
for name, jcfg in raw.get("joints", {}).items():
joints[name] = JointConfig(**jcfg)
config = RobotConfig(
urdf_path=urdf_path,
actuators=actuators,
joints=joints,
)
log.debug("robot_config_loaded", robot_dir=str(robot_dir),
urdf=urdf_filename, num_actuators=len(actuators),
joint_overrides=list(joints.keys()))
return config

View File

@@ -53,9 +53,10 @@ class BaseRunner(abc.ABC, Generic[T]):
def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
... ...
@abc.abstractmethod
def _sim_close(self) -> None: def _sim_close(self) -> None:
... """Release simulator resources. Override for extra cleanup."""
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]: def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
all_ids = torch.arange(self.num_envs, device=self.device) all_ids = torch.arange(self.num_envs, device=self.device)

View File

@@ -29,7 +29,7 @@ import numpy as np
from src.core.env import BaseEnv from src.core.env import BaseEnv
from src.core.runner import BaseRunner, BaseRunnerConfig from src.core.runner import BaseRunner, BaseRunnerConfig
from src.runners.mujoco import MuJoCoRunner # reuse _load_model_with_actuators from src.runners.mujoco import MuJoCoRunner # reuse _load_model
log = structlog.get_logger() log = structlog.get_logger()
@@ -64,14 +64,8 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
# ── Initialization ─────────────────────────────────────────────── # ── Initialization ───────────────────────────────────────────────
def _sim_initialize(self, config: MJXRunnerConfig) -> None: def _sim_initialize(self, config: MJXRunnerConfig) -> None:
model_path = self.env.config.model_path
if model_path is None:
raise ValueError("model_path must be specified")
# Step 1: Load CPU model (reuses URDF → MJCF → actuator injection) # Step 1: Load CPU model (reuses URDF → MJCF → actuator injection)
self._mj_model = MuJoCoRunner._load_model_with_actuators( self._mj_model = MuJoCoRunner._load_model(self.env.robot)
str(model_path), self.env.config.actuators,
)
self._mj_model.opt.timestep = config.dt self._mj_model.opt.timestep = config.dt
self._nq = self._mj_model.nq self._nq = self._mj_model.nq
self._nv = self._mj_model.nv self._nv = self._mj_model.nv
@@ -207,10 +201,6 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
rv = self._batch_data.qvel[ids_np].astype(jnp.float32) rv = self._batch_data.qvel[ids_np].astype(jnp.float32)
return torch.from_dlpack(rq), torch.from_dlpack(rv) return torch.from_dlpack(rq), torch.from_dlpack(rv)
def _sim_close(self) -> None:
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
# ── Rendering ──────────────────────────────────────────────────── # ── Rendering ────────────────────────────────────────────────────
def render(self, env_idx: int = 0) -> np.ndarray: def render(self, env_idx: int = 0) -> np.ndarray:

View File

@@ -6,7 +6,8 @@ import mujoco
import numpy as np import numpy as np
import torch import torch
from src.core.env import BaseEnv, ActuatorConfig from src.core.env import BaseEnv
from src.core.robot import RobotConfig
from src.core.runner import BaseRunner, BaseRunnerConfig from src.core.runner import BaseRunner, BaseRunnerConfig
@dataclasses.dataclass @dataclasses.dataclass
@@ -30,18 +31,18 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
return torch.device(self.config.device) return torch.device(self.config.device)
@staticmethod @staticmethod
def _load_model_with_actuators(model_path: str, actuators: list[ActuatorConfig]) -> mujoco.MjModel: def _load_model(robot: RobotConfig) -> mujoco.MjModel:
"""Load a URDF (or MJCF) file and programmatically inject actuators. """Load a URDF (or MJCF) and apply robot.yaml settings.
Two-step approach required because MuJoCo's URDF parser ignores Two-step approach required because MuJoCo's URDF parser ignores
<actuator> in the <mujoco> extension block: <actuator> in the <mujoco> extension block:
1. Load the URDF → MuJoCo converts it to internal MJCF 1. Load the URDF → MuJoCo converts it to internal MJCF
2. Export the MJCF XML, add <actuator> elements, reload 2. Export the MJCF XML, inject actuators + joint overrides, reload
This keeps the URDF clean and standard — actuator config lives in This keeps the URDF clean (re-exportable from CAD) — all hardware
the env config (Isaac Lab pattern), not in the robot file. tuning lives in robot.yaml.
""" """
abs_path = Path(model_path).resolve() abs_path = robot.urdf_path.resolve()
model_dir = abs_path.parent model_dir = abs_path.parent
is_urdf = abs_path.suffix.lower() == ".urdf" is_urdf = abs_path.suffix.lower() == ".urdf"
@@ -74,33 +75,45 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
else: else:
model_raw = mujoco.MjModel.from_xml_path(str(abs_path)) model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
if not actuators: if not robot.actuators and not robot.joints:
return model_raw return model_raw
# Step 2: Export internal MJCF representation (save next to original # Step 2: Export internal MJCF, inject actuators + joint overrides, reload
# model so relative mesh/asset paths resolve correctly on reload)
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml" tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
try: try:
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw) mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
mjcf_str = tmp_mjcf.read_text() mjcf_str = tmp_mjcf.read_text()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
# at steady state, vel_max = torque / damping.
root = ET.fromstring(mjcf_str) root = ET.fromstring(mjcf_str)
act_elem = ET.SubElement(root, "actuator")
for act in actuators:
ET.SubElement(act_elem, "motor", attrib={
"name": f"{act.joint}_motor",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
})
# Add damping to actuated joints to limit max speed and # ── Inject actuators ────────────────────────────────────
# mimic real motor friction / back-EMF. if robot.actuators:
# vel_max ≈ max_torque / damping act_elem = ET.SubElement(root, "actuator")
joint_damping = {a.joint: a.damping for a in actuators} for act in robot.actuators:
attribs = {
"name": f"{act.joint}_{act.type}",
"joint": act.joint,
"gear": str(act.gear),
"ctrlrange": f"{act.ctrl_range[0]} {act.ctrl_range[1]}",
}
if act.type == "position":
attribs["kp"] = str(act.kp)
if act.kv > 0:
attribs["kv"] = str(act.kv)
ET.SubElement(act_elem, "position", attrib=attribs)
elif act.type == "velocity":
attribs["kp"] = str(act.kp)
ET.SubElement(act_elem, "velocity", attrib=attribs)
else: # motor (default)
ET.SubElement(act_elem, "motor", attrib=attribs)
# ── Apply joint overrides from robot.yaml ───────────────
# Merge actuator damping + explicit joint overrides
joint_damping = {a.joint: a.damping for a in robot.actuators}
for name, jcfg in robot.joints.items():
if jcfg.damping is not None:
joint_damping[name] = jcfg.damping
for body in root.iter("body"): for body in root.iter("body"):
for jnt in body.findall("joint"): for jnt in body.findall("joint"):
name = jnt.get("name") name = jnt.get("name")
@@ -115,6 +128,15 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
geom.set("contype", "0") geom.set("contype", "0")
geom.set("conaffinity", "0") geom.set("conaffinity", "0")
# Harden joint limits: MuJoCo's default soft limits are too
# weak and allow overshoot. Negative solref = hard constraint
# (direct stiffness/damping instead of impedance match).
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("limited") == "true" or jnt.get("range"):
jnt.set("solreflimit", "-1000 -100")
jnt.set("solimplimit", "0.95 0.99 0.001")
# Step 4: Write modified MJCF and reload from file path # Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location) # (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode") modified_xml = ET.tostring(root, encoding="unicode")
@@ -124,12 +146,7 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
tmp_mjcf.unlink(missing_ok=True) tmp_mjcf.unlink(missing_ok=True)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None: def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path self._model = self._load_model(self.env.robot)
if model_path is None:
raise ValueError("model_path must be specified in the environment config")
actuators = self.env.config.actuators
self._model = self._load_model_with_actuators(str(model_path), actuators)
self._model.opt.timestep = config.dt self._model.opt.timestep = config.dt
self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)] self._data: list[mujoco.MjData] = [mujoco.MjData(self._model) for _ in range(config.num_envs)]
@@ -195,10 +212,6 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
torch.from_numpy(qvel_batch).to(self.device), torch.from_numpy(qvel_batch).to(self.device),
) )
def _sim_close(self) -> None:
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close()
def render(self, env_idx: int = 0) -> np.ndarray: def render(self, env_idx: int = 0) -> np.ndarray:
"""Offscreen render of a single environment.""" """Offscreen render of a single environment."""
if not hasattr(self, "_offscreen_renderer"): if not hasattr(self, "_offscreen_renderer"):

View File

@@ -13,39 +13,13 @@ from clearml import Task
from hydra.core.hydra_config import HydraConfig from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig from src.core.env import BaseEnv
from src.core.registry import build_env
from src.core.runner import BaseRunner from src.core.runner import BaseRunner
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.training.trainer import Trainer, TrainerConfig from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger() logger = structlog.get_logger()
# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
"""Instantiate the right env + config from the Hydra config-group name."""
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
# Convert actuator dicts → ActuatorConfig objects
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return env_cls(config_cls(**env_dict))
# ── runner registry ─────────────────────────────────────────────────── # ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass) # Maps Hydra config-group name → (RunnerClass, ConfigClass)
@@ -123,7 +97,7 @@ def main(cfg: DictConfig) -> None:
task = _init_clearml(choices, remote=remote) task = _init_clearml(choices, remote=remote)
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "cartpole")
env = _build_env(env_name, cfg) env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg) runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict) trainer_config = TrainerConfig(**training_dict)
trainer = Trainer(runner=runner, config=trainer_config) trainer = Trainer(runner=runner, config=trainer_config)

25
viz.py
View File

@@ -20,32 +20,11 @@ import torch
from hydra.core.hydra_config import HydraConfig from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig from src.core.registry import build_env
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
logger = structlog.get_logger() logger = structlog.get_logger()
# ── registry (same as train.py) ──────────────────────────────────────
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return env_cls(config_cls(**env_dict))
# ── keyboard state ─────────────────────────────────────────────────── # ── keyboard state ───────────────────────────────────────────────────
_action_val = [0.0] # mutable container shared with callback _action_val = [0.0] # mutable container shared with callback
@@ -72,7 +51,7 @@ def main(cfg: DictConfig) -> None:
env_name = choices.get("env", "cartpole") env_name = choices.get("env", "cartpole")
# Build env + runner (single env for viz) # Build env + runner (single env for viz)
env = _build_env(env_name, cfg) env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True) runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
runner_dict["num_envs"] = 1 runner_dict["num_envs"] = 1
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict)) runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))