update urdf and dependencies

This commit is contained in:
2026-03-09 20:39:02 +01:00
parent c753c369b4
commit 15da0ef2fd
11 changed files with 204 additions and 57 deletions

View File

@@ -17,6 +17,7 @@ class ActuatorConfig:
joint: str = ""
gear: float = 1.0
ctrl_range: tuple[float, float] = (-1.0, 1.0)
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
@dataclasses.dataclass

View File

@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
], dim=-1)
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
# height: sin(θ) → -1 (down) to +1 (up)
height = torch.sin(state.pendulum_angle)
# Upright reward: -cos(θ) ∈ [-1, +1]
upright = -torch.cos(state.pendulum_angle)
# Upright reward: strongly rewards being near vertical.
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
# down (h=-1): 0.0
# horiz (h= 0): 0.0
# up (h=+1): 1.0
# Only kicks in above horizontal, so swing-up isn't penalised.
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
# Velocity penalties — make spinning expensive but allow swing-up
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
motor_vel_penalty = 0.01 * state.motor_vel ** 2
# Motor effort penalty: small cost to avoid bang-bang control.
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
return 5.0 * upright_reward - effort_penalty
return upright - pend_vel_penalty - motor_vel_penalty
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
# No early termination — episode runs for max_steps (truncation only).
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
def get_default_qpos(self, nq: int) -> list[float] | None:
# The STL mesh is horizontal at qpos=0.
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
import math
return [0.0, -math.pi / 2]
# qpos=0 = pendulum hanging down (joint frame rotated in URDF).
return None

View File

@@ -1,11 +1,13 @@
import dataclasses
import os
import xml.etree.ElementTree as ET
from pathlib import Path
import mujoco
import numpy as np
import torch
from src.core.env import BaseEnv, ActuatorConfig
from src.core.runner import BaseRunner, BaseRunnerConfig
import torch
import numpy as np
import mujoco
@dataclasses.dataclass
class MuJoCoRunnerConfig(BaseRunnerConfig):
@@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
This keeps the URDF clean and standard — actuator config lives in
the env config (Isaac Lab pattern), not in the robot file.
"""
abs_path = os.path.abspath(model_path)
model_dir = os.path.dirname(abs_path)
is_urdf = abs_path.lower().endswith(".urdf")
abs_path = Path(model_path).resolve()
model_dir = abs_path.parent
is_urdf = abs_path.suffix.lower() == ".urdf"
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
# so we inject a <mujoco><compiler meshdir="..."/> block into a
@@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
meshdir = None
for mesh_el in root.iter("mesh"):
fn = mesh_el.get("filename", "")
dirname = os.path.dirname(fn)
if dirname:
meshdir = dirname
parent = str(Path(fn).parent)
if parent and parent != ".":
meshdir = parent
break
if meshdir:
mj_ext = ET.SubElement(root, "mujoco")
@@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
"meshdir": meshdir,
"balanceinertia": "true",
})
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
tmp_urdf = model_dir / "_tmp_mujoco_load.urdf"
tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
try:
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
finally:
os.unlink(tmp_urdf)
tmp_urdf.unlink()
else:
model_raw = mujoco.MjModel.from_xml_path(abs_path)
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
if not actuators:
return model_raw
# Step 2: Export internal MJCF representation (save next to original
# model so relative mesh/asset paths resolve correctly on reload)
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
try:
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
with open(tmp_mjcf) as f:
mjcf_str = f.read()
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
mjcf_str = tmp_mjcf.read_text()
# Step 3: Inject actuators into the MJCF XML
# Use torque actuator. Speed is limited by joint damping:
@@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
# Add damping to actuated joints to limit max speed and
# mimic real motor friction / back-EMF.
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
actuated_joints = {a.joint for a in actuators}
# vel_max ≈ max_torque / damping
joint_damping = {a.joint: a.damping for a in actuators}
for body in root.iter("body"):
for jnt in body.findall("joint"):
if jnt.get("name") in actuated_joints:
jnt.set("damping", "0.05")
name = jnt.get("name")
if name in joint_damping:
jnt.set("damping", str(joint_damping[name]))
# Disable self-collision on all geoms.
# URDF mesh convex hulls often overlap at joints (especially
@@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
# Step 4: Write modified MJCF and reload from file path
# (from_xml_path resolves mesh paths relative to the file location)
modified_xml = ET.tostring(root, encoding="unicode")
with open(tmp_mjcf, "w") as f:
f.write(modified_xml)
return mujoco.MjModel.from_xml_path(tmp_mjcf)
tmp_mjcf.write_text(modified_xml)
return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
finally:
if os.path.exists(tmp_mjcf):
os.unlink(tmp_mjcf)
tmp_mjcf.unlink(missing_ok=True)
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
model_path = self.env.config.model_path

View File

@@ -35,6 +35,7 @@ class TrainerConfig:
# Training
total_timesteps: int = 1_000_000
log_interval: int = 10
checkpoint_interval: int = 50_000
# Video recording (uploaded to ClearML)
record_video_every: int = 10_000 # 0 = disabled
@@ -196,7 +197,7 @@ class Trainer:
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
agent_cfg["experiment"]["checkpoint_interval"] = max(
self.config.total_timesteps // 10, self.config.rollout_steps
self.config.checkpoint_interval, self.config.rollout_steps
)
self.agent = PPO(