♻️ Refactor scripts: add HPO controller (ClearML + SMAC3), unified sysid CLI, Hydra training entry point, and interactive MuJoCo visualization
This commit is contained in:
340
scripts/hpo.py
Normal file
340
scripts/hpo.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""Hyperparameter optimization for RL-Framework using ClearML + SMAC3.
|
||||
|
||||
Automatically creates a base training task (via Task.create), reads HPO
|
||||
search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
|
||||
and launches SMAC3 Successive Halving optimization.
|
||||
|
||||
Usage:
|
||||
python scripts/hpo.py \
|
||||
--env rotary_cartpole \
|
||||
--runner mujoco_single \
|
||||
--training ppo_single \
|
||||
--queue gpu-queue
|
||||
|
||||
# Or use an existing base task:
|
||||
python scripts/hpo.py --base-task-id <TASK_ID>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from clearml.automation import (
|
||||
DiscreteParameterRange,
|
||||
HyperParameterOptimizer,
|
||||
UniformIntegerParameterRange,
|
||||
UniformParameterRange,
|
||||
)
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def _load_hydra_config(
    env: str, runner: str, training: str
) -> dict:
    """Load and merge Hydra configs to extract HPO ranges.

    We read the YAML files directly (without running Hydra) so this script
    doesn't need @hydra.main — it's a ClearML optimizer, not a training job.

    NOTE(review): `runner` is accepted but never read here — kept for
    signature symmetry with _create_base_task; confirm before removing.
    """
    configs_dir = Path(__file__).resolve().parent.parent / "configs"

    # Load the training config; its `defaults:` inheritance is resolved
    # manually below since we are outside a Hydra run.
    training_cfg = OmegaConf.load(configs_dir / "training" / f"{training}.yaml")

    if "defaults" in training_cfg:
        merged_base = OmegaConf.create({})
        for entry in OmegaConf.to_container(training_cfg.defaults):
            # Only plain string entries name a sibling training YAML;
            # dict-style defaults entries are ignored.
            if not isinstance(entry, str):
                continue
            candidate = configs_dir / "training" / f"{entry}.yaml"
            if candidate.exists():
                merged_base = OmegaConf.merge(merged_base, OmegaConf.load(candidate))
        # Drop the `defaults` key, then overlay the concrete config on the base.
        overlay = {
            key: val
            for key, val in OmegaConf.to_container(training_cfg).items()
            if key != "defaults"
        }
        training_cfg = OmegaConf.merge(merged_base, OmegaConf.create(overlay))

    # Env config is optional — fall back to an empty config when missing.
    env_path = configs_dir / "env" / f"{env}.yaml"
    env_cfg = OmegaConf.load(env_path) if env_path.exists() else OmegaConf.create({})

    return {
        "training": OmegaConf.to_container(training_cfg, resolve=True),
        "env": OmegaConf.to_container(env_cfg, resolve=True),
    }
|
||||
|
||||
|
||||
def _build_hyper_parameters(config: dict) -> list:
    """Build ClearML parameter ranges from hpo: blocks in config.

    Reads training.hpo and env.hpo dicts and creates appropriate
    ClearML parameter range objects.

    Each hpo entry can have:
        {min, max}            → UniformParameterRange (float)
        {min, max, type: int} → UniformIntegerParameterRange
        {values: [...]}       → DiscreteParameterRange

    NOTE(review): an earlier doc mentioned `{min, max, log: true}` for a
    log scale, but `log` is not handled anywhere below — confirm whether
    log-scale sampling is expected before relying on it.
    """
    params = []

    for section in ("training", "env"):
        hpo_ranges = config.get(section, {}).get("hpo", {})
        if not hpo_ranges:
            continue

        for param_name, spec in hpo_ranges.items():
            # Keys are namespaced the way ClearML exposes Hydra overrides.
            hydra_key = f"Hydra/{section}.{param_name}"

            # Fix: a malformed entry (e.g. a bare scalar or a list) would
            # make the `in` membership tests below raise TypeError —
            # skip it with a warning instead of crashing the whole run.
            if not isinstance(spec, dict):
                logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)
                continue

            if "values" in spec:
                params.append(
                    DiscreteParameterRange(hydra_key, values=spec["values"])
                )
            elif "min" in spec and "max" in spec:
                if spec.get("type") == "int":
                    params.append(
                        UniformIntegerParameterRange(
                            hydra_key,
                            min_value=int(spec["min"]),
                            max_value=int(spec["max"]),
                        )
                    )
                else:
                    # Optional step size yields a coarse float grid.
                    step = spec.get("step", None)
                    params.append(
                        UniformParameterRange(
                            hydra_key,
                            min_value=float(spec["min"]),
                            max_value=float(spec["max"]),
                            step_size=step,
                        )
                    )
            else:
                logger.warning("skipping_unknown_hpo_spec", param=param_name, spec=spec)

    return params
|
||||
|
||||
|
||||
def _create_base_task(
    env: str, runner: str, training: str, queue: str
) -> str:
    """Create a base ClearML task without executing it.

    Uses Task.create() to register a task pointing at scripts/train.py
    with the correct Hydra overrides. The HPO optimizer will clone this.

    NOTE(review): `queue` is accepted but never used in this function —
    trials appear to be enqueued by the optimizer instead; confirm before
    dropping the parameter.
    """
    scripts_dir = Path(__file__).resolve().parent
    repo_root = scripts_dir.parent

    base_task = Task.create(
        project_name="RL-Framework",
        task_name=f"{env}-{runner}-{training} (HPO base)",
        task_type=Task.TaskTypes.training,
        script=str(scripts_dir / "train.py"),
        working_directory=str(repo_root),
        argparse_args=[
            f"env={env}",
            f"runner={runner}",
            f"training={training}",
        ],
        add_task_init_call=False,
    )

    # Docker setup: GL libraries for headless MuJoCo rendering plus CUDA JAX.
    base_task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=(
            "apt-get update && apt-get install -y --no-install-recommends "
            "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
            "&& pip install 'jax[cuda12]' mujoco-mjx"
        ),
    )

    base_task.set_packages(str(repo_root / "requirements.txt"))

    logger.info("base_task_created", task_id=base_task.id, task_name=base_task.name)
    return base_task.id
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the HPO controller."""
    parser = argparse.ArgumentParser(
        description="Hyperparameter optimization for RL-Framework"
    )
    parser.add_argument(
        "--base-task-id",
        type=str,
        default=None,
        help="Existing ClearML task ID to use as base (skip auto-creation)",
    )
    parser.add_argument("--env", type=str, default="rotary_cartpole")
    parser.add_argument("--runner", type=str, default="mujoco_single")
    parser.add_argument("--training", type=str, default="ppo_single")
    parser.add_argument("--queue", type=str, default="gpu-queue")
    parser.add_argument(
        "--max-concurrent", type=int, default=2,
        help="Maximum concurrent trial tasks",
    )
    parser.add_argument(
        "--total-trials", type=int, default=200,
        help="Total HPO trial budget",
    )
    parser.add_argument(
        "--min-budget", type=int, default=3,
        help="Minimum budget (epochs) per trial",
    )
    parser.add_argument(
        "--max-budget", type=int, default=81,
        help="Maximum budget (epochs) for promoted trials",
    )
    parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
    parser.add_argument(
        "--time-limit-hours", type=float, default=72,
        help="Total wall-clock time limit in hours",
    )
    parser.add_argument(
        "--objective-metric", type=str, default="Reward / Total reward (mean)",
        help="ClearML scalar metric title to optimize",
    )
    parser.add_argument(
        "--objective-series", type=str, default=None,
        help="ClearML scalar metric series (default: same as title)",
    )
    # Fix: --maximize/--minimize were two independent store_true flags,
    # and --maximize defaulted to True (a no-op that silently lost to
    # --minimize). Make the pair mutually exclusive so passing both is
    # rejected instead of quietly resolved.
    direction = parser.add_mutually_exclusive_group()
    direction.add_argument(
        "--maximize", action="store_true", default=True,
        help="Maximize the objective (default)",
    )
    direction.add_argument(
        "--minimize", action="store_true", default=False,
        help="Minimize the objective",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Print search space and exit without running",
    )
    return parser.parse_args()


def _log_top_experiments(optimizer, top_k: int = 10, max_retries: int = 5) -> None:
    """Fetch and log the best trials, retrying with exponential backoff."""
    for attempt in range(max_retries):
        try:
            top_exp = optimizer.get_top_experiments(top_k=top_k)
            logger.info("top_experiments_retrieved", count=len(top_exp))
            for i, t in enumerate(top_exp):
                logger.info("top_experiment", rank=i + 1, task_id=t.id, name=t.name)
            break
        except Exception as e:
            logger.warning("retry_get_top_experiments", attempt=attempt + 1, error=str(e))
            if attempt < max_retries - 1:
                time.sleep(5.0 * (2 ** attempt))
            else:
                logger.error("could_not_retrieve_top_experiments")


def main() -> None:
    """HPO controller entry point.

    Builds the search space from the Hydra YAMLs, creates (or reuses) a
    base training task, then runs a ClearML HyperParameterOptimizer with
    the SMAC3 strategy on a remote services worker.
    """
    args = _parse_args()
    objective_sign = "min" if args.minimize else "max"

    # ── Load config and build search space ────────────────────────
    config = _load_hydra_config(args.env, args.runner, args.training)
    hyper_parameters = _build_hyper_parameters(config)

    if not hyper_parameters:
        logger.error(
            "no_hpo_ranges_found",
            hint="Add 'hpo:' blocks to your training and/or env YAML configs",
        )
        return

    if args.dry_run:
        print(f"\nSearch space ({len(hyper_parameters)} parameters):")
        for p in hyper_parameters:
            print(f" {p.name}: {p}")
        print(f"\nObjective: {args.objective_metric} ({objective_sign})")
        return

    # ── Create or reuse base task ─────────────────────────────────
    if args.base_task_id:
        base_task_id = args.base_task_id
        logger.info("using_existing_base_task", task_id=base_task_id)
    else:
        base_task_id = _create_base_task(
            args.env, args.runner, args.training, args.queue
        )

    # ── Initialize ClearML HPO task ───────────────────────────────
    Task.ignore_requirements("torch")
    task = Task.init(
        project_name="RL-Framework",
        task_name=f"HPO {args.env}-{args.runner}-{args.training}",
        task_type=Task.TaskTypes.optimizer,
        reuse_last_task_id=False,
    )
    task.set_base_docker(
        docker_image="registry.kube.optimize/worker-image:latest",
        docker_arguments=[
            "-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
            "-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
            "-e", "CLEARML_AGENT_FORCE_SYSTEM_SITE_PACKAGES=1",
        ],
    )
    req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
    task.set_packages(str(req_file))

    # ── Build objective metric ────────────────────────────────────
    # skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
    objective_title = args.objective_metric
    objective_series = args.objective_series or objective_title

    # ── Launch optimizer ──────────────────────────────────────────
    # Deferred import: SMAC3 is only needed when actually optimizing.
    from src.hpo.smac3 import OptimizerSMAC

    optimizer = HyperParameterOptimizer(
        base_task_id=base_task_id,
        hyper_parameters=hyper_parameters,
        objective_metric_title=objective_title,
        objective_metric_series=objective_series,
        objective_metric_sign=objective_sign,
        optimizer_class=OptimizerSMAC,
        execution_queue=args.queue,
        max_number_of_concurrent_tasks=args.max_concurrent,
        total_max_jobs=args.total_trials,
        min_iteration_per_job=args.min_budget,
        max_iteration_per_job=args.max_budget,
        pool_period_min=1,
        time_limit_per_job=240,  # 4 hours per trial max
        eta=args.eta,
    )

    # Send this HPO controller to a remote services worker; the worker
    # re-runs the script and execute_remotely() is a no-op there.
    task.execute_remotely(queue_name="services", exit_process=True)

    # Reporting and time limits
    optimizer.set_report_period(1)
    optimizer.set_time_limit(in_minutes=int(args.time_limit_hours * 60))

    # Start and wait for the whole budget (or time limit) to be consumed.
    optimizer.start()
    optimizer.wait()

    _log_top_experiments(optimizer)

    optimizer.stop()
    logger.info("hpo_complete")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
57
scripts/sysid.py
Normal file
57
scripts/sysid.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Unified CLI for system identification tools.
|
||||
|
||||
Usage:
|
||||
python scripts/sysid.py capture --robot-path assets/rotary_cartpole --duration 20
|
||||
python scripts/sysid.py optimize --robot-path assets/rotary_cartpole --recording <file>.npz
|
||||
python scripts/sysid.py visualize --recording <file>.npz
|
||||
python scripts/sysid.py export --robot-path assets/rotary_cartpole --result <result>.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
|
||||
def main() -> None:
    """Dispatch a sysid subcommand to its implementing module."""
    # Subcommand → module that provides the matching `main()`.
    commands = {
        "capture": "src.sysid.capture",
        "optimize": "src.sysid.optimize",
        "visualize": "src.sysid.visualize",
        "export": "src.sysid.export",
    }

    if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
        print(
            "Usage: python scripts/sysid.py <command> [options]\n"
            "\n"
            "Commands:\n"
            " capture Record real robot trajectory under PRBS excitation\n"
            " optimize Run CMA-ES parameter optimization\n"
            " visualize Plot real vs simulated trajectories\n"
            " export Write tuned URDF + robot.yaml files\n"
            "\n"
            "Run 'python scripts/sysid.py <command> --help' for command-specific options."
        )
        sys.exit(0)

    command = sys.argv[1]
    # Strip the subcommand from argv so the target module's argparse works normally.
    sys.argv = [f"sysid {command}"] + sys.argv[2:]

    module_path = commands.get(command)
    if module_path is None:
        print(f"Unknown command: {command}")
        print("Available commands: capture, optimize, visualize, export")
        sys.exit(1)

    # Import lazily so only the selected tool's dependencies are loaded.
    import importlib
    importlib.import_module(module_path).main()
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
118
scripts/train.py
Normal file
118
scripts/train.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
# Ensure project root is on sys.path so `src.*` imports work
|
||||
# regardless of which directory the script is invoked from.
|
||||
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.registry import build_env
|
||||
from src.core.runner import BaseRunner
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (module path, runner class, config class).
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.

RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
    "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}


def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Instantiate the right runner from the Hydra config-group name.

    Raises:
        ValueError: if `runner_name` has no entry in RUNNER_REGISTRY.
    """
    try:
        module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
    except KeyError:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        ) from None

    # Deferred import keeps heavy backends (e.g. JAX for mjx) lazy.
    import importlib
    module = importlib.import_module(module_path)
    runner_cls = getattr(module, cls_name)
    config_cls = getattr(module, cfg_cls_name)

    runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
    return runner_cls(env=env, config=runner_config)
|
||||
|
||||
|
||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Initialize ClearML task with project structure and tags."""
    Task.ignore_requirements("torch")

    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")

    # Name and tag the task after the selected Hydra config groups.
    task = Task.init(
        project_name="RL-Framework",
        task_name=f"{env_name}-{runner_name}-{training_name}",
        tags=[env_name, runner_name, training_name],
    )

    # Docker setup: GL libraries for headless MuJoCo rendering plus CUDA JAX.
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=(
            "apt-get update && apt-get install -y --no-install-recommends "
            "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
            "&& pip install 'jax[cuda12]' mujoco-mjx"
        ),
    )

    task.set_packages(
        str(pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt")
    )

    # Hand off to a remote worker when requested; the worker re-runs the
    # script and execute_remotely() is a no-op on the worker side.
    if remote and task.running_locally():
        logger.info("executing_task_remotely", queue="gpu-queue")
        task.execute_remotely(queue_name="gpu-queue", exit_process=True)

    return task
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Training entry point: build env + runner from Hydra choices and train."""
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # execute_remotely() is a no-op on the worker side.
    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    remote = training_dict.pop("remote", False)
    training_dict.pop("hpo", None)  # HPO range metadata — not a TrainerConfig field
    task = _init_clearml(choices, remote=remote)

    env = build_env(choices.get("env", "cartpole"), cfg)
    runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
    trainer = Trainer(runner=runner, config=TrainerConfig(**training_dict))

    # Always release trainer and ClearML resources, even when training fails.
    try:
        trainer.train()
    finally:
        trainer.close()
        task.close()
|
||||
|
||||
|
||||
# Script entry point (Hydra parses CLI overrides).
if __name__ == "__main__":
    main()
|
||||
254
scripts/viz.py
Normal file
254
scripts/viz.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||
|
||||
Usage (simulation):
|
||||
mjpython scripts/viz.py env=rotary_cartpole
|
||||
mjpython scripts/viz.py env=cartpole +com=true
|
||||
|
||||
Usage (real hardware — digital twin):
|
||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial runner.port=/dev/ttyUSB0
|
||||
|
||||
Controls:
|
||||
Left/Right arrows — apply torque to first actuator
|
||||
R — reset environment
|
||||
Esc / close window — quit
|
||||
"""
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
_PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import numpy as np
|
||||
import structlog
|
||||
import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
# Single-element lists act as mutable cells shared with the viewer callback.
_action_val = [0.0]     # current commanded action in [-1, 1]
_action_time = [0.0]    # timestamp of last key press
_reset_flag = [False]   # set by 'R', consumed by the main loop
_ACTION_HOLD_S = 0.25   # seconds the action stays active after last key event


def _key_callback(keycode: int) -> None:
    """Called by MuJoCo on key press & repeat (not release)."""
    # GLFW keycodes: 263 = LEFT, 262 = RIGHT, 82 = 'R'.
    if keycode in (262, 263):
        _action_val[0] = 1.0 if keycode == 262 else -1.0
        _action_time[0] = time.time()
    elif keycode == 82:
        _reset_flag[0] = True
|
||||
|
||||
|
||||
def _add_action_arrow(viewer, model, data, action_val: float) -> None:
    """Draw an arrow on the motor joint showing applied torque direction.

    Args:
        viewer: passive MuJoCo viewer; overlay geoms are written into
            ``viewer.user_scn``.
        model: mjModel, used for actuator/joint/body lookups.
        data: mjData with current world-frame kinematics.
        action_val: normalized action; sign selects direction and color,
            magnitude scales arrow length. Values below 0.01 draw nothing.
    """
    if abs(action_val) < 0.01 or model.nu == 0:
        return

    # Fix: guard the scene's fixed geom budget — indexing
    # user_scn.geoms[ngeom] at or past maxgeom is out of bounds.
    if viewer.user_scn.ngeom >= viewer.user_scn.maxgeom:
        return

    # Get the body that the first actuator's joint belongs to
    jnt_id = model.actuator_trnid[0, 0]
    body_id = model.jnt_bodyid[jnt_id]

    # Arrow origin: body position
    pos = data.xpos[body_id].copy()
    pos[2] += 0.02  # lift slightly above the body

    # Arrow direction: along joint axis in world frame, scaled by action
    axis = data.xmat[body_id].reshape(3, 3) @ model.jnt_axis[jnt_id]
    arrow_len = 0.08 * action_val
    direction = axis * np.sign(arrow_len)

    # Build rotation matrix: arrow rendered along local z-axis
    z = direction / (np.linalg.norm(direction) + 1e-8)
    # Pick a non-parallel "up" to avoid a degenerate cross product.
    up = np.array([0, 0, 1]) if abs(z[2]) < 0.99 else np.array([0, 1, 0])
    x = np.cross(up, z)
    x /= np.linalg.norm(x) + 1e-8
    y = np.cross(z, x)
    mat = np.column_stack([x, y, z]).flatten()

    # Color: green = positive, red = negative
    rgba = np.array(
        [0.2, 0.8, 0.2, 0.8] if action_val > 0 else [0.8, 0.2, 0.2, 0.8],
        dtype=np.float32,
    )

    geom = viewer.user_scn.geoms[viewer.user_scn.ngeom]
    mujoco.mjv_initGeom(
        geom,
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        size=np.array([0.008, 0.008, abs(arrow_len)]),
        pos=pos,
        mat=mat,
        rgba=rgba,
    )
    viewer.user_scn.ngeom += 1
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Entry point: route to simulation or digital-twin visualization."""
    choices = HydraConfig.get().runtime.choices
    env_name = choices.get("env", "cartpole")

    # runner=serial mirrors the real robot; anything else steps MuJoCo physics.
    if choices.get("runner", "mujoco") == "serial":
        _main_serial(cfg, env_name)
    else:
        _main_sim(cfg, env_name)
|
||||
|
||||
|
||||
def _main_sim(cfg: DictConfig, env_name: str) -> None:
    """Simulation visualization — step MuJoCo physics with keyboard control."""
    # Build env + a single-env MuJoCo runner (viz drives exactly one env).
    env = build_env(env_name, cfg)
    runner_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_kwargs["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_kwargs))

    model = runner._model
    data = runner._data[0]

    # One control step spans `substeps` physics steps.
    control_period = runner.config.dt * runner.config.substeps

    with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
        # Show CoM / inertia if requested via Hydra override: viz.py +com=true
        if cfg.get("com", False):
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

        obs, _ = runner.reset()
        step = 0

        logger.info("viewer_started", env=env_name,
                    controls="Left/Right arrows = torque, R = reset")

        while viewer.is_running():
            # Keyboard action expires _ACTION_HOLD_S after the last key event.
            held = time.time() - _action_time[0] < _ACTION_HOLD_S
            action_val = _action_val[0] if held else 0.0

            # Reset on R press
            if _reset_flag[0]:
                _reset_flag[0] = False
                obs, _ = runner.reset()
                step = 0
                logger.info("reset")

            # Step through runner
            obs, reward, terminated, truncated, info = runner.step(
                torch.tensor([[action_val]])
            )

            # Sync viewer with action arrow overlay
            mujoco.mj_forward(model, data)
            viewer.user_scn.ngeom = 0  # clear previous frame's overlays
            _add_action_arrow(viewer, model, data, action_val)
            viewer.sync()

            # Periodic state dump (every 25 control steps).
            if step % 25 == 0:
                joint_angles = {
                    model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
                    for i in range(model.njnt)
                }
                logger.debug("step", n=step, reward=round(reward.item(), 3),
                             action=round(action_val, 1), **joint_angles)

            # Real-time pacing
            time.sleep(control_period)
            step += 1

    runner.close()
|
||||
|
||||
|
||||
def _main_serial(cfg: DictConfig, env_name: str) -> None:
    """Digital-twin visualization — mirror real hardware in MuJoCo viewer.

    The MuJoCo model is loaded for rendering only. Joint positions are
    read from the ESP32 over serial and applied to the model each frame.
    Keyboard arrows send motor commands to the real robot.
    """
    # Deferred import: pulls in pyserial only when actually talking to hardware.
    from src.runners.serial import SerialRunner, SerialRunnerConfig

    env = build_env(env_name, cfg)
    runner_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    serial_runner = SerialRunner(env=env, config=SerialRunnerConfig(**runner_kwargs))

    # Load MuJoCo model for visualisation (same URDF the sim uses).
    serial_runner._ensure_viz_model()
    model = serial_runner._viz_model
    data = serial_runner._viz_data

    with mujoco.viewer.launch_passive(
        model, data, key_callback=_key_callback
    ) as viewer:
        # Show CoM / inertia if requested.
        if cfg.get("com", False):
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
            viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

        logger.info(
            "viewer_started",
            env=env_name,
            mode="serial (digital twin)",
            port=serial_runner.config.port,
            controls="Left/Right arrows = motor command, R = reset",
        )

        while viewer.is_running():
            # Keyboard action expires _ACTION_HOLD_S after the last key event.
            held = time.time() - _action_time[0] < _ACTION_HOLD_S
            action_val = _action_val[0] if held else 0.0

            # Reset on R: stop motor, re-center, wait for pendulum to settle.
            if _reset_flag[0]:
                _reset_flag[0] = False
                serial_runner._send("M0")
                serial_runner._drive_to_center()
                serial_runner._wait_for_pendulum_still()
                logger.info("reset (drive-to-center + settle)")

            # Send motor command to real hardware (clamped, scaled to ±255).
            motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
            serial_runner._send(f"M{motor_speed}")

            # Sync MuJoCo model with real sensor data.
            serial_runner._sync_viz()

            # Render overlays and sync viewer.
            viewer.user_scn.ngeom = 0
            _add_action_arrow(viewer, model, data, action_val)
            viewer.sync()

            # Real-time pacing (~50 Hz, matches serial dt).
            time.sleep(serial_runner.config.dt)

    serial_runner.close()
|
||||
|
||||
|
||||
# Script entry point (run with mjpython for the macOS MuJoCo viewer).
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user