✨ update urdf and dependencies
This commit is contained in:
@@ -17,6 +17,7 @@ class ActuatorConfig:
|
||||
joint: str = ""
|
||||
gear: float = 1.0
|
||||
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
], dim=-1)
|
||||
|
||||
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
||||
# height: sin(θ) → -1 (down) to +1 (up)
|
||||
height = torch.sin(state.pendulum_angle)
|
||||
# Upright reward: -cos(θ) ∈ [-1, +1]
|
||||
upright = -torch.cos(state.pendulum_angle)
|
||||
|
||||
# Upright reward: strongly rewards being near vertical.
|
||||
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
|
||||
# down (h=-1): 0.0
|
||||
# horiz (h= 0): 0.0
|
||||
# up (h=+1): 1.0
|
||||
# Only kicks in above horizontal, so swing-up isn't penalised.
|
||||
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
|
||||
# Velocity penalties — make spinning expensive but allow swing-up
|
||||
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
|
||||
motor_vel_penalty = 0.01 * state.motor_vel ** 2
|
||||
|
||||
# Motor effort penalty: small cost to avoid bang-bang control.
|
||||
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
|
||||
|
||||
return 5.0 * upright_reward - effort_penalty
|
||||
return upright - pend_vel_penalty - motor_vel_penalty
|
||||
|
||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||
# No early termination — episode runs for max_steps (truncation only).
|
||||
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
|
||||
|
||||
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||
# The STL mesh is horizontal at qpos=0.
|
||||
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
|
||||
import math
|
||||
return [0.0, -math.pi / 2]
|
||||
# qpos=0 = pendulum hanging down (joint frame rotated in URDF).
|
||||
return None
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv, ActuatorConfig
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
import torch
|
||||
import numpy as np
|
||||
import mujoco
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||
@@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
This keeps the URDF clean and standard — actuator config lives in
|
||||
the env config (Isaac Lab pattern), not in the robot file.
|
||||
"""
|
||||
abs_path = os.path.abspath(model_path)
|
||||
model_dir = os.path.dirname(abs_path)
|
||||
is_urdf = abs_path.lower().endswith(".urdf")
|
||||
abs_path = Path(model_path).resolve()
|
||||
model_dir = abs_path.parent
|
||||
is_urdf = abs_path.suffix.lower() == ".urdf"
|
||||
|
||||
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||
@@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
meshdir = None
|
||||
for mesh_el in root.iter("mesh"):
|
||||
fn = mesh_el.get("filename", "")
|
||||
dirname = os.path.dirname(fn)
|
||||
if dirname:
|
||||
meshdir = dirname
|
||||
parent = str(Path(fn).parent)
|
||||
if parent and parent != ".":
|
||||
meshdir = parent
|
||||
break
|
||||
if meshdir:
|
||||
mj_ext = ET.SubElement(root, "mujoco")
|
||||
@@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
"meshdir": meshdir,
|
||||
"balanceinertia": "true",
|
||||
})
|
||||
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
||||
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
||||
tmp_urdf = model_dir / "_tmp_mujoco_load.urdf"
|
||||
tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
|
||||
try:
|
||||
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
|
||||
finally:
|
||||
os.unlink(tmp_urdf)
|
||||
tmp_urdf.unlink()
|
||||
else:
|
||||
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
||||
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
|
||||
|
||||
if not actuators:
|
||||
return model_raw
|
||||
|
||||
# Step 2: Export internal MJCF representation (save next to original
|
||||
# model so relative mesh/asset paths resolve correctly on reload)
|
||||
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
||||
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
|
||||
try:
|
||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
||||
with open(tmp_mjcf) as f:
|
||||
mjcf_str = f.read()
|
||||
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
|
||||
mjcf_str = tmp_mjcf.read_text()
|
||||
|
||||
# Step 3: Inject actuators into the MJCF XML
|
||||
# Use torque actuator. Speed is limited by joint damping:
|
||||
@@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
|
||||
# Add damping to actuated joints to limit max speed and
|
||||
# mimic real motor friction / back-EMF.
|
||||
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
||||
actuated_joints = {a.joint for a in actuators}
|
||||
# vel_max ≈ max_torque / damping
|
||||
joint_damping = {a.joint: a.damping for a in actuators}
|
||||
for body in root.iter("body"):
|
||||
for jnt in body.findall("joint"):
|
||||
if jnt.get("name") in actuated_joints:
|
||||
jnt.set("damping", "0.05")
|
||||
name = jnt.get("name")
|
||||
if name in joint_damping:
|
||||
jnt.set("damping", str(joint_damping[name]))
|
||||
|
||||
# Disable self-collision on all geoms.
|
||||
# URDF mesh convex hulls often overlap at joints (especially
|
||||
@@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
# Step 4: Write modified MJCF and reload from file path
|
||||
# (from_xml_path resolves mesh paths relative to the file location)
|
||||
modified_xml = ET.tostring(root, encoding="unicode")
|
||||
with open(tmp_mjcf, "w") as f:
|
||||
f.write(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
||||
tmp_mjcf.write_text(modified_xml)
|
||||
return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
|
||||
finally:
|
||||
if os.path.exists(tmp_mjcf):
|
||||
os.unlink(tmp_mjcf)
|
||||
tmp_mjcf.unlink(missing_ok=True)
|
||||
|
||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||
model_path = self.env.config.model_path
|
||||
|
||||
@@ -35,6 +35,7 @@ class TrainerConfig:
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
checkpoint_interval: int = 50_000
|
||||
|
||||
# Video recording (uploaded to ClearML)
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
@@ -196,7 +197,7 @@ class Trainer:
|
||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||
self.config.total_timesteps // 10, self.config.rollout_steps
|
||||
self.config.checkpoint_interval, self.config.rollout_steps
|
||||
)
|
||||
|
||||
self.agent = PPO(
|
||||
|
||||
Reference in New Issue
Block a user