✨ update urdf and dependencies
This commit is contained in:
@@ -36,9 +36,9 @@
|
|||||||
<link name="arm">
|
<link name="arm">
|
||||||
<inertial>
|
<inertial>
|
||||||
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
<origin xyz="0.00005 0.0065 0.00563" rpy="0 0 0"/>
|
||||||
<mass value="0.150"/>
|
<mass value="0.010"/>
|
||||||
<inertia ixx="4.05e-05" iyy="1.17e-05" izz="3.66e-05"
|
<inertia ixx="2.70e-06" iyy="7.80e-07" izz="2.44e-06"
|
||||||
ixy="0.0" iyz="1.08e-07" ixz="0.0"/>
|
ixy="0.0" iyz="7.20e-08" ixz="0.0"/>
|
||||||
</inertial>
|
</inertial>
|
||||||
<visual>
|
<visual>
|
||||||
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
<origin xyz="0.006947 -0.00395 -0.14796" rpy="0 0 0"/>
|
||||||
@@ -73,7 +73,7 @@
|
|||||||
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
Tip at (0.07, -0.07, 0) → 45° diagonal in +X/-Y.
|
||||||
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
CoM = (5×0.035+10×0.07)/15 = 0.0583 along both +X and -Y.
|
||||||
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
Inertia tensor rotated 45° to match diagonal rod axis. -->
|
||||||
<origin xyz="0.0583 -0.0583 0.0" rpy="0 0 0"/>
|
<origin xyz="0.1583 -0.0983 -0.0" rpy="0 0 0"/>
|
||||||
<mass value="0.015"/>
|
<mass value="0.015"/>
|
||||||
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
<inertia ixx="6.16e-06" iyy="6.16e-06" izz="1.23e-05"
|
||||||
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
ixy="6.10e-06" iyz="0.0" ixz="0.0"/>
|
||||||
@@ -93,13 +93,14 @@
|
|||||||
</link>
|
</link>
|
||||||
|
|
||||||
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
<!-- Pendulum joint: arm → pendulum, bearing axis along Y.
|
||||||
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off). -->
|
Joint origin corrected from mesh analysis (Fusion2URDF was 180mm off).
|
||||||
|
rpy pitch +90° so qpos=0 = pendulum hanging down (gravity-stable). -->
|
||||||
<joint name="pendulum_joint" type="continuous">
|
<joint name="pendulum_joint" type="continuous">
|
||||||
<origin xyz="0.000052 0.019274 0.014993" rpy="0 0 0"/>
|
<origin xyz="0.000052 0.019274 0.014993" rpy="0 1.5708 0"/>
|
||||||
<parent link="arm"/>
|
<parent link="arm"/>
|
||||||
<child link="pendulum"/>
|
<child link="pendulum"/>
|
||||||
<axis xyz="0 -1 0"/>
|
<axis xyz="0 -1 0"/>
|
||||||
<dynamics damping="0.0005"/>
|
<dynamics damping="0.0001"/>
|
||||||
</joint>
|
</joint>
|
||||||
|
|
||||||
</robot>
|
</robot>
|
||||||
|
|||||||
1
configs/env/cartpole.yaml
vendored
1
configs/env/cartpole.yaml
vendored
@@ -9,3 +9,4 @@ actuators:
|
|||||||
- joint: cart_joint
|
- joint: cart_joint
|
||||||
gear: 10.0
|
gear: 10.0
|
||||||
ctrl_range: [-1.0, 1.0]
|
ctrl_range: [-1.0, 1.0]
|
||||||
|
damping: 0.05
|
||||||
|
|||||||
3
configs/env/rotary_cartpole.yaml
vendored
3
configs/env/rotary_cartpole.yaml
vendored
@@ -3,5 +3,6 @@ model_path: assets/rotary_cartpole/rotary_cartpole.urdf
|
|||||||
reward_upright_scale: 1.0
|
reward_upright_scale: 1.0
|
||||||
actuators:
|
actuators:
|
||||||
- joint: motor_joint
|
- joint: motor_joint
|
||||||
gear: 15.0
|
gear: 0.5
|
||||||
ctrl_range: [-1.0, 1.0]
|
ctrl_range: [-1.0, 1.0]
|
||||||
|
damping: 0.1
|
||||||
@@ -9,7 +9,8 @@ learning_rate: 0.0003
|
|||||||
clip_ratio: 0.2
|
clip_ratio: 0.2
|
||||||
value_loss_scale: 0.5
|
value_loss_scale: 0.5
|
||||||
entropy_loss_scale: 0.05
|
entropy_loss_scale: 0.05
|
||||||
log_interval: 10
|
log_interval: 1000
|
||||||
|
checkpoint_interval: 50000
|
||||||
|
|
||||||
# ClearML remote execution (GPU worker)
|
# ClearML remote execution (GPU worker)
|
||||||
remote: false
|
remote: false
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ skrl[torch]
|
|||||||
clearml
|
clearml
|
||||||
imageio
|
imageio
|
||||||
imageio-ffmpeg
|
imageio-ffmpeg
|
||||||
|
structlog
|
||||||
pytest
|
pytest
|
||||||
@@ -17,6 +17,7 @@ class ActuatorConfig:
|
|||||||
joint: str = ""
|
joint: str = ""
|
||||||
gear: float = 1.0
|
gear: float = 1.0
|
||||||
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
ctrl_range: tuple[float, float] = (-1.0, 1.0)
|
||||||
|
damping: float = 0.05 # joint damping — limits max speed: vel_max ≈ torque / damping
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
|||||||
@@ -66,21 +66,14 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
|||||||
], dim=-1)
|
], dim=-1)
|
||||||
|
|
||||||
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
def compute_rewards(self, state: RotaryCartPoleState, actions: torch.Tensor) -> torch.Tensor:
|
||||||
# height: sin(θ) → -1 (down) to +1 (up)
|
# Upright reward: -cos(θ) ∈ [-1, +1]
|
||||||
height = torch.sin(state.pendulum_angle)
|
upright = -torch.cos(state.pendulum_angle)
|
||||||
|
|
||||||
# Upright reward: strongly rewards being near vertical.
|
# Velocity penalties — make spinning expensive but allow swing-up
|
||||||
# Uses cos(θ - π/2) = sin(θ), squared and scaled so:
|
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
|
||||||
# down (h=-1): 0.0
|
motor_vel_penalty = 0.01 * state.motor_vel ** 2
|
||||||
# horiz (h= 0): 0.0
|
|
||||||
# up (h=+1): 1.0
|
|
||||||
# Only kicks in above horizontal, so swing-up isn't penalised.
|
|
||||||
upright_reward = torch.clamp(height, 0.0, 1.0) ** 2
|
|
||||||
|
|
||||||
# Motor effort penalty: small cost to avoid bang-bang control.
|
return upright - pend_vel_penalty - motor_vel_penalty
|
||||||
effort_penalty = 0.001 * actions.squeeze(-1) ** 2
|
|
||||||
|
|
||||||
return 5.0 * upright_reward - effort_penalty
|
|
||||||
|
|
||||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||||
# No early termination — episode runs for max_steps (truncation only).
|
# No early termination — episode runs for max_steps (truncation only).
|
||||||
@@ -88,7 +81,5 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
|||||||
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
|
return torch.zeros_like(state.motor_angle, dtype=torch.bool)
|
||||||
|
|
||||||
def get_default_qpos(self, nq: int) -> list[float] | None:
|
def get_default_qpos(self, nq: int) -> list[float] | None:
|
||||||
# The STL mesh is horizontal at qpos=0.
|
# qpos=0 = pendulum hanging down (joint frame rotated in URDF).
|
||||||
# Pendulum hangs down at θ = -π/2 (sin(-π/2) = -1).
|
return None
|
||||||
import math
|
|
||||||
return [0.0, -math.pi / 2]
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mujoco
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from src.core.env import BaseEnv, ActuatorConfig
|
from src.core.env import BaseEnv, ActuatorConfig
|
||||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import mujoco
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
class MuJoCoRunnerConfig(BaseRunnerConfig):
|
||||||
@@ -39,9 +41,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
This keeps the URDF clean and standard — actuator config lives in
|
This keeps the URDF clean and standard — actuator config lives in
|
||||||
the env config (Isaac Lab pattern), not in the robot file.
|
the env config (Isaac Lab pattern), not in the robot file.
|
||||||
"""
|
"""
|
||||||
abs_path = os.path.abspath(model_path)
|
abs_path = Path(model_path).resolve()
|
||||||
model_dir = os.path.dirname(abs_path)
|
model_dir = abs_path.parent
|
||||||
is_urdf = abs_path.lower().endswith(".urdf")
|
is_urdf = abs_path.suffix.lower() == ".urdf"
|
||||||
|
|
||||||
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
# MuJoCo's URDF parser strips directory prefixes from mesh filenames,
|
||||||
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
# so we inject a <mujoco><compiler meshdir="..."/> block into a
|
||||||
@@ -53,9 +55,9 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
meshdir = None
|
meshdir = None
|
||||||
for mesh_el in root.iter("mesh"):
|
for mesh_el in root.iter("mesh"):
|
||||||
fn = mesh_el.get("filename", "")
|
fn = mesh_el.get("filename", "")
|
||||||
dirname = os.path.dirname(fn)
|
parent = str(Path(fn).parent)
|
||||||
if dirname:
|
if parent and parent != ".":
|
||||||
meshdir = dirname
|
meshdir = parent
|
||||||
break
|
break
|
||||||
if meshdir:
|
if meshdir:
|
||||||
mj_ext = ET.SubElement(root, "mujoco")
|
mj_ext = ET.SubElement(root, "mujoco")
|
||||||
@@ -63,25 +65,24 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
"meshdir": meshdir,
|
"meshdir": meshdir,
|
||||||
"balanceinertia": "true",
|
"balanceinertia": "true",
|
||||||
})
|
})
|
||||||
tmp_urdf = os.path.join(model_dir, "_tmp_mujoco_load.urdf")
|
tmp_urdf = model_dir / "_tmp_mujoco_load.urdf"
|
||||||
tree.write(tmp_urdf, xml_declaration=True, encoding="unicode")
|
tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
|
||||||
try:
|
try:
|
||||||
model_raw = mujoco.MjModel.from_xml_path(tmp_urdf)
|
model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_urdf)
|
tmp_urdf.unlink()
|
||||||
else:
|
else:
|
||||||
model_raw = mujoco.MjModel.from_xml_path(abs_path)
|
model_raw = mujoco.MjModel.from_xml_path(str(abs_path))
|
||||||
|
|
||||||
if not actuators:
|
if not actuators:
|
||||||
return model_raw
|
return model_raw
|
||||||
|
|
||||||
# Step 2: Export internal MJCF representation (save next to original
|
# Step 2: Export internal MJCF representation (save next to original
|
||||||
# model so relative mesh/asset paths resolve correctly on reload)
|
# model so relative mesh/asset paths resolve correctly on reload)
|
||||||
tmp_mjcf = os.path.join(model_dir, "_tmp_actuator_inject.xml")
|
tmp_mjcf = model_dir / "_tmp_actuator_inject.xml"
|
||||||
try:
|
try:
|
||||||
mujoco.mj_saveLastXML(tmp_mjcf, model_raw)
|
mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
|
||||||
with open(tmp_mjcf) as f:
|
mjcf_str = tmp_mjcf.read_text()
|
||||||
mjcf_str = f.read()
|
|
||||||
|
|
||||||
# Step 3: Inject actuators into the MJCF XML
|
# Step 3: Inject actuators into the MJCF XML
|
||||||
# Use torque actuator. Speed is limited by joint damping:
|
# Use torque actuator. Speed is limited by joint damping:
|
||||||
@@ -98,12 +99,13 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
|
|
||||||
# Add damping to actuated joints to limit max speed and
|
# Add damping to actuated joints to limit max speed and
|
||||||
# mimic real motor friction / back-EMF.
|
# mimic real motor friction / back-EMF.
|
||||||
# vel_max ≈ max_torque / damping (e.g. 1.0 / 0.05 = 20 rad/s)
|
# vel_max ≈ max_torque / damping
|
||||||
actuated_joints = {a.joint for a in actuators}
|
joint_damping = {a.joint: a.damping for a in actuators}
|
||||||
for body in root.iter("body"):
|
for body in root.iter("body"):
|
||||||
for jnt in body.findall("joint"):
|
for jnt in body.findall("joint"):
|
||||||
if jnt.get("name") in actuated_joints:
|
name = jnt.get("name")
|
||||||
jnt.set("damping", "0.05")
|
if name in joint_damping:
|
||||||
|
jnt.set("damping", str(joint_damping[name]))
|
||||||
|
|
||||||
# Disable self-collision on all geoms.
|
# Disable self-collision on all geoms.
|
||||||
# URDF mesh convex hulls often overlap at joints (especially
|
# URDF mesh convex hulls often overlap at joints (especially
|
||||||
@@ -116,12 +118,10 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
|||||||
# Step 4: Write modified MJCF and reload from file path
|
# Step 4: Write modified MJCF and reload from file path
|
||||||
# (from_xml_path resolves mesh paths relative to the file location)
|
# (from_xml_path resolves mesh paths relative to the file location)
|
||||||
modified_xml = ET.tostring(root, encoding="unicode")
|
modified_xml = ET.tostring(root, encoding="unicode")
|
||||||
with open(tmp_mjcf, "w") as f:
|
tmp_mjcf.write_text(modified_xml)
|
||||||
f.write(modified_xml)
|
return mujoco.MjModel.from_xml_path(str(tmp_mjcf))
|
||||||
return mujoco.MjModel.from_xml_path(tmp_mjcf)
|
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(tmp_mjcf):
|
tmp_mjcf.unlink(missing_ok=True)
|
||||||
os.unlink(tmp_mjcf)
|
|
||||||
|
|
||||||
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
def _sim_initialize(self, config: MuJoCoRunnerConfig) -> None:
|
||||||
model_path = self.env.config.model_path
|
model_path = self.env.config.model_path
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class TrainerConfig:
|
|||||||
# Training
|
# Training
|
||||||
total_timesteps: int = 1_000_000
|
total_timesteps: int = 1_000_000
|
||||||
log_interval: int = 10
|
log_interval: int = 10
|
||||||
|
checkpoint_interval: int = 50_000
|
||||||
|
|
||||||
# Video recording (uploaded to ClearML)
|
# Video recording (uploaded to ClearML)
|
||||||
record_video_every: int = 10_000 # 0 = disabled
|
record_video_every: int = 10_000 # 0 = disabled
|
||||||
@@ -196,7 +197,7 @@ class Trainer:
|
|||||||
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
# log_interval=1 → log every PPO update (= every rollout_steps timesteps).
|
||||||
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
agent_cfg["experiment"]["write_interval"] = self.config.log_interval
|
||||||
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
agent_cfg["experiment"]["checkpoint_interval"] = max(
|
||||||
self.config.total_timesteps // 10, self.config.rollout_steps
|
self.config.checkpoint_interval, self.config.rollout_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
self.agent = PPO(
|
self.agent = PPO(
|
||||||
|
|||||||
16
train.py
16
train.py
@@ -1,4 +1,8 @@
|
|||||||
|
import pathlib
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
import hydra.utils as hydra_utils
|
||||||
|
import structlog
|
||||||
from clearml import Task
|
from clearml import Task
|
||||||
from hydra.core.hydra_config import HydraConfig
|
from hydra.core.hydra_config import HydraConfig
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
@@ -9,6 +13,8 @@ from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
|||||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||||
from src.training.trainer import Trainer, TrainerConfig
|
from src.training.trainer import Trainer, TrainerConfig
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
# ── env registry ──────────────────────────────────────────────────────
|
# ── env registry ──────────────────────────────────────────────────────
|
||||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||||
@@ -52,9 +58,15 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
|||||||
tags = [env_name, runner_name, training_name]
|
tags = [env_name, runner_name, training_name]
|
||||||
|
|
||||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||||
|
task.set_base_docker("registry.kube.optimize/worker-image:latest")
|
||||||
|
|
||||||
if remote:
|
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||||
task.execute_remotely(queue_name="default")
|
task.set_packages(str(req_file))
|
||||||
|
|
||||||
|
# Execute remotely if requested and running locally
|
||||||
|
if remote and task.running_locally():
|
||||||
|
logger.info("executing_task_remotely", queue="gpu-queue")
|
||||||
|
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|||||||
137
viz.py
Normal file
137
viz.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
mjpython viz.py env=rotary_cartpole
|
||||||
|
mjpython viz.py env=cartpole +com=true
|
||||||
|
|
||||||
|
Controls:
|
||||||
|
Left/Right arrows — apply torque to first actuator
|
||||||
|
R — reset environment
|
||||||
|
Esc / close window — quit
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import mujoco
|
||||||
|
import mujoco.viewer
|
||||||
|
import structlog
|
||||||
|
import torch
|
||||||
|
from hydra.core.hydra_config import HydraConfig
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
|
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
|
||||||
|
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
|
||||||
|
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||||
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||||
|
|
||||||
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
|
# ── registry (same as train.py) ──────────────────────────────────────
|
||||||
|
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||||
|
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||||
|
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Instantiate the environment registered under ``env_name``.

    Args:
        env_name: Key into ``ENV_REGISTRY`` (the Hydra ``env`` config-group choice).
        cfg: Full Hydra config; only ``cfg.env`` is read here.

    Returns:
        The environment built from its config dataclass.

    Raises:
        ValueError: If ``env_name`` is not a key of ``ENV_REGISTRY``.
    """
    if env_name not in ENV_REGISTRY:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")

    env_cls, config_cls = ENV_REGISTRY[env_name]

    # Resolve interpolations and drop down to plain dicts / lists so the
    # values can be splatted into the config dataclasses.
    env_kwargs = OmegaConf.to_container(cfg.env, resolve=True)

    if "actuators" in env_kwargs:
        actuator_cfgs = []
        for spec in env_kwargs["actuators"]:
            # The container conversion yields a list for ctrl_range; hand the
            # actuator config a tuple instead.
            if "ctrl_range" in spec:
                spec["ctrl_range"] = tuple(spec["ctrl_range"])
            actuator_cfgs.append(ActuatorConfig(**spec))
        env_kwargs["actuators"] = actuator_cfgs

    env_config = config_cls(**env_kwargs)
    return env_cls(env_config)
|
||||||
|
|
||||||
|
|
||||||
|
# ── keyboard state ───────────────────────────────────────────────────
|
||||||
|
# Single-element lists act as mutable cells: the viewer key callback writes
# them and the main loop reads them.
_action_val = [0.0]    # last requested action value (-1.0 or +1.0)
_action_time = [0.0]   # time.time() stamp of the last arrow-key event
_reset_flag = [False]  # raised when R is pressed; the main loop clears it
_ACTION_HOLD_S = 0.25  # seconds the action stays active after last key event

# GLFW key code -> action value written by the arrow keys.
_ARROW_ACTIONS = {
    263: -1.0,  # GLFW_KEY_LEFT
    262: 1.0,   # GLFW_KEY_RIGHT
}


def _key_callback(keycode: int) -> None:
    """Record a key event from the MuJoCo viewer (press & repeat, not release).

    Left/Right arrows store -1.0 / +1.0 in ``_action_val`` and stamp
    ``_action_time``; R raises ``_reset_flag``. Any other key is ignored.
    """
    if keycode == 82:  # GLFW_KEY_R
        _reset_flag[0] = True
        return

    requested = _ARROW_ACTIONS.get(keycode)
    if requested is None:
        return

    _action_val[0] = requested
    _action_time[0] = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def _joint_readout(model, data) -> dict:
    """Return ``{joint_name: position}`` for the periodic debug log.

    Each joint's value is read at ``model.jnt_qposadr[j]`` rather than at the
    joint index, so the lookup stays correct if a joint occupies more than one
    qpos slot. Hinge joints are converted to degrees; slide joints are reported
    in their raw qpos unit (converting a linear displacement to degrees is
    meaningless). Ball / free joints are skipped because they have no single
    scalar position.
    """
    hinge = int(mujoco.mjtJoint.mjJNT_HINGE)
    slide = int(mujoco.mjtJoint.mjJNT_SLIDE)

    readout = {}
    for j in range(model.njnt):
        kind = int(model.jnt_type[j])
        if kind not in (hinge, slide):
            continue
        q = float(data.qpos[model.jnt_qposadr[j]])
        name = model.jnt(j).name
        readout[name] = round(math.degrees(q), 1) if kind == hinge else round(q, 4)
    return readout


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Run one environment in the MuJoCo passive viewer under keyboard control.

    Builds the env selected by the Hydra ``env`` choice, wraps it in a
    single-env ``MuJoCoRunner`` and steps it until the viewer window closes.
    The action comes from the arrow keys (see ``_key_callback``) and expires
    ``_ACTION_HOLD_S`` seconds after the last key event; R resets the runner.

    Args:
        cfg: Hydra config. ``cfg.env`` and ``cfg.runner`` are read; an optional
            ``com`` flag (``+com=true``) turns on CoM / inertia rendering.
    """
    choices = HydraConfig.get().runtime.choices
    env_name = choices.get("env", "cartpole")

    # Build env + runner (single env for viz)
    env = _build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    # try/finally so the runner is closed even if the viewer loop raises.
    try:
        model = runner._model
        data = runner._data[0]

        # Control period
        dt_ctrl = runner.config.dt * runner.config.substeps

        # Launch viewer
        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
            show_com = cfg.get("com", False)
            if show_com:
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            obs, _ = runner.reset()
            step = 0

            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            while viewer.is_running():
                tick_start = time.monotonic()

                # Read action from callback (expires after _ACTION_HOLD_S).
                # Compared against time.time() because that is the clock the
                # key callback stamps _action_time with.
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    logger.info("reset")

                # Step through runner
                action = torch.tensor([[action_val]])
                obs, reward, terminated, truncated, info = runner.step(action)

                # Sync viewer
                mujoco.mj_forward(model, data)
                viewer.sync()

                # Print state
                if step % 25 == 0:
                    joints = _joint_readout(model, data)
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1), **joints)

                # Real-time pacing: sleep only for what is left of the control
                # period after stepping and rendering, so one loop iteration
                # takes roughly dt_ctrl of wall time instead of dt_ctrl plus
                # the work time.
                remaining = dt_ctrl - (time.monotonic() - tick_start)
                if remaining > 0:
                    time.sleep(remaining)
                step += 1
    finally:
        runner.close()


if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user