♻️ cleanup
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Serial runner — communicates with real hardware over USB/serial.
|
||||
# Always single-env, CPU-only. Override port on CLI:
|
||||
# python train.py runner=serial runner.port=/dev/ttyUSB0
|
||||
# python scripts/train.py runner=serial runner.port=/dev/ttyUSB0
|
||||
|
||||
num_envs: 1
|
||||
device: cpu
|
||||
|
||||
@@ -170,9 +170,12 @@ def _create_base_task(
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
|
||||
),
|
||||
docker_arguments=[
|
||||
"-e", "MUJOCO_GL=osmesa",
|
||||
],
|
||||
)
|
||||
|
||||
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||
@@ -214,6 +217,10 @@ def main() -> None:
|
||||
help="Maximum budget (total_timesteps) for promoted trials",
|
||||
)
|
||||
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
|
||||
parser.add_argument(
|
||||
"--max-consecutive-failures", type=int, default=3,
|
||||
help="Abort HPO after N consecutive trial failures (0 = never abort)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time-limit-hours", type=float, default=72,
|
||||
help="Total wall-clock time limit in hours",
|
||||
@@ -312,6 +319,7 @@ def main() -> None:
|
||||
time_limit_per_job=240, # 4 hours per trial max
|
||||
eta=args.eta,
|
||||
budget_param_name="Hydra/training.total_timesteps",
|
||||
max_consecutive_failures=args.max_consecutive_failures,
|
||||
)
|
||||
|
||||
# Send this HPO controller to a remote services worker
|
||||
|
||||
@@ -8,8 +8,9 @@ _PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
|
||||
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
|
||||
if sys.platform == "linux":
|
||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||
|
||||
import hydra
|
||||
@@ -73,9 +74,12 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
|
||||
),
|
||||
docker_arguments=[
|
||||
"-e", "MUJOCO_GL=osmesa",
|
||||
],
|
||||
)
|
||||
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
|
||||
@@ -204,6 +204,12 @@ class OptimizerSMAC(SearchStrategy):
|
||||
self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
|
||||
self.time_limit_per_job = time_limit_per_job # Store time limit (minutes)
|
||||
|
||||
# Consecutive-failure abort: stop HPO if N trials in a row crash
|
||||
self.max_consecutive_failures = int(
|
||||
smac_kwargs.pop("max_consecutive_failures", 3)
|
||||
)
|
||||
self._consecutive_failures = 0
|
||||
|
||||
# Checkpoint continuation tracking: config_key -> {budget: task_id}
|
||||
# Used to find the previous task's checkpoint when promoting a config
|
||||
self.config_to_tasks = {} # config_key -> {budget: task_id}
|
||||
@@ -395,6 +401,25 @@ class OptimizerSMAC(SearchStrategy):
|
||||
if not self.running_tasks and not self.pending_configs:
|
||||
break
|
||||
|
||||
# Abort if too many consecutive trials failed (likely a config bug)
|
||||
if (
|
||||
self.max_consecutive_failures > 0
|
||||
and self._consecutive_failures >= self.max_consecutive_failures
|
||||
):
|
||||
controller.get_logger().report_text(
|
||||
f"ABORTING: {self._consecutive_failures} consecutive trial "
|
||||
f"failures (limit: {self.max_consecutive_failures}). "
|
||||
"Check the trial logs for errors."
|
||||
)
|
||||
# Stop any still-running tasks
|
||||
for tid in list(self.running_tasks):
|
||||
with contextlib.suppress(Exception):
|
||||
t = self._get_task_safe(task_id=tid)
|
||||
if t:
|
||||
t.mark_stopped(force=True)
|
||||
self.running_tasks.clear()
|
||||
break
|
||||
|
||||
# Poll for finished or timed out
|
||||
done = []
|
||||
timed_out = []
|
||||
@@ -463,14 +488,31 @@ class OptimizerSMAC(SearchStrategy):
|
||||
del self._task_check_failures[tid]
|
||||
|
||||
task = self._get_task_safe(task_id=tid)
|
||||
|
||||
# Detect hard-failed tasks (crashed / errored) vs completed
|
||||
task_failed = False
|
||||
if task is not None:
|
||||
st = task.get_status()
|
||||
task_failed = st in (
|
||||
Task.TaskStatusEnum.failed,
|
||||
Task.TaskStatusEnum.stopped,
|
||||
)
|
||||
|
||||
if task is None:
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
task_failed = True
|
||||
else:
|
||||
res = self._get_objective(task)
|
||||
|
||||
if res is None or res == float("-inf") or res == float("inf"):
|
||||
res = float("-inf") if self.maximize_metric else float("inf")
|
||||
|
||||
# Track consecutive failures for abort logic
|
||||
if task_failed:
|
||||
self._consecutive_failures += 1
|
||||
else:
|
||||
self._consecutive_failures = 0 # reset on any success
|
||||
|
||||
cost = -res if self.maximize_metric else res
|
||||
self.smac.tell(tri, TrialValue(cost=cost))
|
||||
self.completed_results.append(
|
||||
|
||||
12
train.py
12
train.py
@@ -2,8 +2,9 @@ import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
|
||||
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
|
||||
if sys.platform == "linux":
|
||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||
|
||||
import hydra
|
||||
@@ -71,9 +72,12 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
|
||||
),
|
||||
docker_arguments=[
|
||||
"-e", "MUJOCO_GL=osmesa",
|
||||
],
|
||||
)
|
||||
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
|
||||
247
viz.py
247
viz.py
@@ -1,247 +0,0 @@
|
||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||
|
||||
Usage (simulation):
|
||||
mjpython viz.py env=rotary_cartpole
|
||||
mjpython viz.py env=cartpole +com=true
|
||||
|
||||
Usage (real hardware — digital twin):
|
||||
mjpython viz.py env=rotary_cartpole runner=serial
|
||||
mjpython viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
||||
|
||||
Controls:
|
||||
Left/Right arrows — apply torque to first actuator
|
||||
R — reset environment
|
||||
Esc / close window — quit
|
||||
"""
|
||||
import math
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists act as mutable cells: the viewer key callback writes
# them in place and the render loops read them, so no `global` is needed.
_ACTION_HOLD_S: float = 0.25  # seconds the action stays active after last key event
_reset_flag: list[bool] = [False]  # raised by the R key, cleared by the render loop
_action_time: list[float] = [0.0]  # timestamp of last key press
_action_val: list[float] = [0.0]  # mutable container shared with callback
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
    """Record a viewer key event in the shared keyboard state.

    MuJoCo calls this on key press and key repeat (not on release), so
    arrow presses are timestamped and the render loop decides when the
    action has expired.
    """
    if keycode == 82:  # GLFW_KEY_R
        _reset_flag[0] = True
        return

    # 263 = GLFW_KEY_LEFT, 262 = GLFW_KEY_RIGHT
    signed_action = {263: -1.0, 262: 1.0}.get(keycode)
    if signed_action is None:
        return  # every other key is ignored

    _action_val[0] = signed_action
    _action_time[0] = time.time()
|
||||
|
||||
|
||||
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow on the motor joint showing applied torque direction.

    Appends one arrow geom to ``viewer.user_scn`` and bumps its ``ngeom``
    counter. Nothing is drawn when the action is (near) zero, when the model
    has no actuators, or when the user scene has no free geom slot.

    Args:
        viewer: Viewer handle exposing ``user_scn``.
        model: MuJoCo model; only the first actuator is visualised.
        data: MuJoCo data providing the current body pose.
        action_val: Signed action. The sign selects direction and colour,
            the magnitude scales the arrow length.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    scene = viewer.user_scn
    # The user scene has a fixed geom capacity; skip the overlay rather than
    # index past the end of the geom buffer.
    if scene.ngeom >= scene.maxgeom:
        return

    # Get the body that the first actuator's joint belongs to.
    # NOTE(review): assumes the first actuator has a joint transmission — confirm.
    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]

    # Arrow origin: body position
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02  # lift slightly above the body

    # Arrow direction: along joint axis in world frame, flipped by action sign
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Build rotation matrix: arrow rendered along local z-axis.
    # The epsilon terms keep the normalisations finite for degenerate vectors.
    z = direction / (np.linalg.norm(direction) + 1e-8)
    # Pick a helper "up" vector that is not (nearly) parallel to z.
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    # Color: green = positive, red = negative
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    mujoco.mjv_initGeom(
        scene.geoms[scene.ngeom],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    scene.ngeom += 1
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: choose the visualization mode from the runner choice."""
    selected = HydraConfig.get().runtime.choices
    env_name = selected.get("env", "cartpole")
    runner_name = selected.get("runner", "mujoco")

    # "serial" mirrors real hardware; any other runner choice uses the simulator.
    launch = _main_serial if runner_name == "serial" else _main_sim
    launch(cfg, env_name)
|
||||
|
||||
|
||||
def _main_sim(cfg: DictConfig, env_name: str) -> None:
    """Simulation visualization — step MuJoCo physics with keyboard control.

    Builds a single-env ``MuJoCoRunner``, opens a passive viewer on its model
    and data, and steps the runner with the keyboard-driven action until the
    viewer window is closed. The runner is closed even if the loop raises.

    Args:
        cfg: Hydra config; ``cfg.runner`` supplies the runner settings and an
            optional ``com`` flag enables CoM / inertia rendering.
        env_name: Name of the env to build via ``build_env``.
    """
    # Build env + runner (single env for viz)
    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    try:
        model = runner._model
        data = runner._data[0]

        # Control period
        dt_ctrl = runner.config.dt * runner.config.substeps

        # Launch viewer
        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
            show_com = cfg.get("com", False)
            if show_com:
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            obs, _ = runner.reset()
            step = 0

            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            while viewer.is_running():
                frame_start = time.monotonic()

                # Read action from callback (expires after _ACTION_HOLD_S)
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    obs, _ = runner.reset()
                    step = 0
                    logger.info("reset")

                # Step through runner
                action = torch.tensor([[action_val]])
                obs, reward, terminated, truncated, info = runner.step(action)

                # Sync viewer with action arrow overlay
                mujoco.mj_forward(model, data)
                viewer.user_scn.ngeom = 0  # clear previous frame's overlays
                _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                # Print state
                if step % 25 == 0:
                    # Look up each joint's slot in qpos through jnt_qposadr:
                    # the joint index equals the qpos index only while every
                    # preceding joint has a single position coordinate.
                    # NOTE(review): values are logged in degrees, which only
                    # makes sense for rotational joints — confirm for slide joints.
                    joints = {
                        model.jnt(i).name: round(
                            math.degrees(data.qpos[model.jnt_qposadr[i]]), 1
                        )
                        for i in range(model.njnt)
                    }
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1), **joints)

                # Real-time pacing: sleep only for what is left of the control
                # period after this frame's stepping and rendering work.
                elapsed = time.monotonic() - frame_start
                time.sleep(max(0.0, dt_ctrl - elapsed))
                step += 1
    finally:
        # Always release the runner, including when the viewer loop raises.
        runner.close()
|
||||
|
||||
|
||||
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.

    The serial runner is closed even if the viewer loop raises, so the
    hardware link is not left open after a crash.

    Args:
        cfg: Hydra config; ``cfg.runner`` supplies the serial runner settings
            and an optional ``com`` flag enables CoM / inertia rendering.
        env_name: Name of the env to build via ``build_env``.
    """
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(
        env=env, config=SerialRunnerConfig(**runner_dict)
    )

    try:
        # Load MuJoCo model for visualisation (same URDF the sim uses).
        serial_runner._ensure_viz_model()
        model = serial_runner._viz_model
        data = serial_runner._viz_data

        with mujoco.viewer.launch_passive(
            model, data, key_callback=_key_callback
        ) as viewer:
            # Show CoM / inertia if requested.
            show_com = cfg.get("com", False)
            if show_com:
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            logger.info(
                "viewer_started",
                env=env_name,
                mode="serial (digital twin)",
                port=serial_runner.config.port,
                controls="Left/Right arrows = motor command, R = reset",
            )

            while viewer.is_running():
                # Read action from keyboard callback.
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press.
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    serial_runner._send("M0")
                    serial_runner._drive_to_center()
                    serial_runner._wait_for_pendulum_still()
                    logger.info("reset (drive-to-center + settle)")

                # Send motor command to real hardware: the clipped action is
                # scaled to an integer in [-255, 255].
                motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
                serial_runner._send(f"M{motor_speed}")

                # Sync MuJoCo model with real sensor data.
                serial_runner._sync_viz()

                # Render overlays and sync viewer.
                viewer.user_scn.ngeom = 0
                _add_action_arrow(viewer, model, data, action_val)
                viewer.sync()

                # Real-time pacing (~50 Hz, matches serial dt).
                time.sleep(serial_runner.config.dt)
    finally:
        # Always release the serial runner, including on errors mid-loop.
        # NOTE(review): presumably close() also stops the motor — confirm.
        serial_runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user