update hpo

2026-03-11 23:28:39 +01:00
parent 23801857f4
commit 3b2d6d08f9
3 changed files with 75 additions and 145 deletions


@@ -132,6 +132,21 @@ def _build_hyper_parameters(config: dict) -> list:
    return params

def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
    """Flatten a nested dict into dot-separated keys.

    Example: {"a": {"b": 1}} → {"a.b": 1}
    """
    items = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(_flatten_dict(v, new_key, sep=sep))
        else:
            items[new_key] = v
    return items

def _create_base_task(
    env: str, runner: str, training: str, queue: str
) -> str:
@@ -139,6 +154,8 @@ def _create_base_task(
    Uses Task.create() to register a task pointing at scripts/train.py
    with the correct Hydra overrides. The HPO optimizer will clone this.

    The full resolved OmegaConf config is attached as Hydra/* parameters
    so cloned trial tasks inherit the complete configuration.
    """
    script_path = str(Path(__file__).resolve().parent / "train.py")
    project_root = str(Path(__file__).resolve().parent.parent)
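For orientation (this wiring is outside the diff): the consuming side typically looks like the sketch below, assuming ClearML's HyperParameterOptimizer, an illustrative learning-rate range, and the skrl metric names mentioned later in this file. Each trial clones the base task and overwrites the sampled "Hydra/..." keys before enqueueing.

from clearml.automation import HyperParameterOptimizer, UniformParameterRange

optimizer = HyperParameterOptimizer(
    base_task_id=base_task_id,  # id returned by _create_base_task()
    hyper_parameters=[
        # Illustrative range; the real ranges come from the repo's hpo config.
        UniformParameterRange("Hydra/training.learning_rate", 1e-5, 1e-2),
    ],
    objective_metric_title="Reward",
    objective_metric_series="Total reward (mean)",
    objective_metric_sign="max",
    optimizer_class=OptimizerSMAC,  # the SearchStrategy subclass changed below
    execution_queue="gpu-queue",    # queue name illustrative
)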
@@ -157,14 +174,44 @@ def _create_base_task(
        add_task_init_call=False,
    )

    # ── Attach full resolved OmegaConf config ─────────────────────
    # ClearML's Hydra binding normally does this when the script runs,
    # but Task.create() never executes Hydra. We replicate the binding
    # manually: config group choices + all resolved values.
    base_task.set_parameter("Hydra/env", env)
    base_task.set_parameter("Hydra/runner", runner)
    base_task.set_parameter("Hydra/training", training)
    # Load and resolve the full config for each group
    configs_dir = Path(__file__).resolve().parent.parent / "configs"
    for section, name in [("training", training), ("env", env), ("runner", runner)]:
        cfg_path = configs_dir / section / f"{name}.yaml"
        if not cfg_path.exists():
            continue
        cfg = OmegaConf.load(cfg_path)
        # Handle Hydra defaults inheritance (e.g. ppo_single → ppo)
        if "defaults" in cfg:
            defaults = OmegaConf.to_container(cfg.defaults)
            base_cfg = OmegaConf.create({})
            for d in defaults:
                if isinstance(d, str):
                    base_path = configs_dir / section / f"{d}.yaml"
                    if base_path.exists():
                        loaded = OmegaConf.load(base_path)
                        base_cfg = OmegaConf.merge(base_cfg, loaded)
            cfg_no_defaults = {
                k: v for k, v in OmegaConf.to_container(cfg).items()
                if k != "defaults"
            }
            cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
        resolved = OmegaConf.to_container(cfg, resolve=True)
        # Remove hpo metadata — not a real config value
        resolved.pop("hpo", None)
        flat = _flatten_dict(resolved)
        for key, value in flat.items():
            base_task.set_parameter(f"Hydra/{section}.{key}", value)
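    # Worked example of the defaults merge above, with hypothetical
    # ppo.yaml / ppo_single.yaml contents (not this repo's real configs):
    #   parent: {"learning_rate": 3e-4, "total_timesteps": 1_000_000}  # ppo.yaml
    #   child:  {"defaults": ["ppo"], "num_envs": 1}                   # ppo_single.yaml
    #   merge order: start from {}, merge parent, then child minus "defaults"
    #   result: {"learning_rate": 3e-4, "total_timesteps": 1_000_000, "num_envs": 1}
    # Child keys override parent keys, matching Hydra's own defaults semantics.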
    # Set docker config
    base_task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
@@ -267,15 +314,6 @@ def main() -> None:
        print(f"\nObjective: {args.objective_metric} ({objective_sign})")
        return

    # ── Create or reuse base task ─────────────────────────────────
    if args.base_task_id:
        base_task_id = args.base_task_id
        logger.info("using_existing_base_task", task_id=base_task_id)
    else:
        base_task_id = _create_base_task(
            args.env, args.runner, args.training, args.queue
        )
    # ── Initialize ClearML HPO task ───────────────────────────────
    Task.ignore_requirements("torch")
    task = Task.init(
@@ -295,6 +333,24 @@ def main() -> None:
    req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
    task.set_packages(str(req_file))

    # ── Create or reuse base task ─────────────────────────────────
    # Store the base_task_id on the HPO task so that when the services
    # worker re-runs this script it reuses the same base task instead
    # of creating a duplicate.
    if args.base_task_id:
        base_task_id = args.base_task_id
        logger.info("using_existing_base_task", task_id=base_task_id)
    else:
        existing = task.get_parameter("General/base_task_id")
        if existing:
            base_task_id = existing
            logger.info("reusing_base_task_from_param", task_id=base_task_id)
        else:
            base_task_id = _create_base_task(
                args.env, args.runner, args.training, args.queue
            )
            task.set_parameter("General/base_task_id", base_task_id)

    # ── Build objective metric ────────────────────────────────────
    # skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
    objective_title = args.objective_metric
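The reuse works because a ClearML agent re-running this script reattaches Task.init() to the already-registered task (the agent exposes the task id to the worker process), so parameters stored on the first run are readable again. A minimal round-trip sketch with illustrative names:

from clearml import Task

task = Task.init(project_name="RL-Framework", task_name="hpo-demo")  # names illustrative
base_task_id = task.get_parameter("General/base_task_id")  # None on the first run
if not base_task_id:
    base_task_id = "<newly created base task id>"  # stand-in for _create_base_task()
    task.set_parameter("General/base_task_id", base_task_id)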

View File

@@ -181,6 +181,12 @@ class OptimizerSMAC(SearchStrategy):
            "budget_param_name", "Hydra/training.total_timesteps"
        )

        # Pop our custom kwargs BEFORE passing smac_kwargs to SuccessiveHalving
        self.max_consecutive_failures = int(
            smac_kwargs.pop("max_consecutive_failures", 3)
        )
        self._consecutive_failures = 0

        # build the Successive Halving intensifier (NOT Hyperband!)
        # Hyperband runs multiple brackets with different starting budgets - wasteful
        # Successive Halving: ALL configs start at min_budget, only best get promoted
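To make the schedule concrete, a small sketch of Successive Halving promotion with assumed numbers (eta, budgets, and config count are not this repo's settings):

eta, budget, max_budget, n_configs = 3, 100_000, 900_000, 27
while n_configs >= 1 and budget <= max_budget:
    print(f"{n_configs:>2} configs @ {budget:>7} timesteps -> promote top {max(n_configs // eta, 1)}")
    n_configs //= eta
    budget *= eta
# 27 configs @  100000 timesteps -> promote top 9
#  9 configs @  300000 timesteps -> promote top 3
#  3 configs @  900000 timesteps -> promote top 1

Every configuration starts at min_budget; Hyperband would additionally run brackets that start some configs at larger budgets, which is the waste the comment above avoids.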
@@ -204,12 +210,6 @@ class OptimizerSMAC(SearchStrategy):
        self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
        self.time_limit_per_job = time_limit_per_job  # Store time limit (minutes)

        # Consecutive-failure abort: stop HPO if N trials in a row crash
        self.max_consecutive_failures = int(
            smac_kwargs.pop("max_consecutive_failures", 3)
        )
        self._consecutive_failures = 0

        # Checkpoint continuation tracking: config_key -> {budget: task_id}
        # Used to find the previous task's checkpoint when promoting a config
        self.config_to_tasks = {}  # config_key -> {budget: task_id}
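The diff does not show where the failure counter is consumed; a hypothetical sketch of the intended abort logic (method name invented):

def _record_trial_outcome(self, trial_failed: bool) -> None:
    # Reset on any success; abort after N crashes in a row.
    if not trial_failed:
        self._consecutive_failures = 0
        return
    self._consecutive_failures += 1
    if self._consecutive_failures >= self.max_consecutive_failures:
        raise RuntimeError(
            f"Aborting HPO: {self._consecutive_failures} consecutive trial failures"
        )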

train.py

@@ -1,126 +0,0 @@
import os
import pathlib
import sys
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
if sys.platform == "linux":
    os.environ.setdefault("MUJOCO_GL", "osmesa")
import hydra
import hydra.utils as hydra_utils
import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import BaseEnv
from src.core.registry import build_env
from src.core.runner import BaseRunner
from src.training.trainer import Trainer, TrainerConfig
logger = structlog.get_logger()
# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
    "serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}

def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Instantiate the right runner from the Hydra config-group name."""
    if runner_name not in RUNNER_REGISTRY:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        )
    module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
    import importlib

    mod = importlib.import_module(module_path)
    runner_cls = getattr(mod, cls_name)
    config_cls = getattr(mod, cfg_cls_name)
    runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
    return runner_cls(env=env, config=runner_config)

def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Initialize ClearML task with project structure and tags.

    Project: RL-Trainings/<EnvName> (e.g. RL-Trainings/Rotary Cartpole)
    Tags: env, runner, training algo choices from Hydra.
    """
    Task.ignore_requirements("torch")
    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")
    project = "RL-Framework"
    task_name = f"{env_name}-{runner_name}-{training_name}"
    tags = [env_name, runner_name, training_name]
    task = Task.init(project_name=project, task_name=task_name, tags=tags)
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=(
            "apt-get update && apt-get install -y --no-install-recommends "
            "libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
            "&& pip install 'jax[cuda12]' mujoco-mjx PyOpenGL PyOpenGL-accelerate"
        ),
        docker_arguments=[
            "-e", "MUJOCO_GL=osmesa",
        ],
    )
    req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
    task.set_packages(str(req_file))

    # Execute remotely if requested and running locally
    if remote and task.running_locally():
        logger.info("executing_task_remotely", queue="gpu-queue")
        task.execute_remotely(queue_name="gpu-queue", exit_process=True)
    return task

@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # execute_remotely() is a no-op on the worker side.
    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    remote = training_dict.pop("remote", False)
    training_dict.pop("hpo", None)  # HPO range metadata — not a TrainerConfig field
    task = _init_clearml(choices, remote=remote)

    # Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
    # resume_from_task_id or any future additions)
    import dataclasses as _dc

    _valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
    training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}

    env_name = choices.get("env", "cartpole")
    env = build_env(env_name, cfg)
    runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
    trainer_config = TrainerConfig(**training_dict)
    trainer = Trainer(runner=runner, config=trainer_config)
    try:
        trainer.train()
    finally:
        trainer.close()
        task.close()

if __name__ == "__main__":
    main()