update hpo

This commit is contained in:
2026-03-11 23:28:39 +01:00
parent 23801857f4
commit 3b2d6d08f9
3 changed files with 75 additions and 145 deletions

View File

@@ -132,6 +132,21 @@ def _build_hyper_parameters(config: dict) -> list:
return params
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
    """Collapse a nested dict into a single level with joined keys.

    Values that are themselves ``dict`` instances are descended into
    recursively; every other value (lists included) is kept as a leaf.
    A leaf's key is the path of keys leading to it, joined with ``sep``.
    An empty nested dict contributes no entries.

    Args:
        d: The (possibly nested) dict to flatten.
        parent_key: Prefix prepended to every produced key. When it is
            falsy, top-level keys are emitted unchanged.
        sep: Separator placed between path components.

    Returns:
        A new flat dict; ``d`` itself is not modified.

    Example: {"a": {"b": 1}} → {"a.b": 1}
    """

    def _leaves(node, prefix):
        # Depth-first walk in insertion order, yielding (path, leaf) pairs.
        for name, value in node.items():
            path = f"{prefix}{sep}{name}" if prefix else name
            if isinstance(value, dict):
                yield from _leaves(value, path)
            else:
                yield path, value

    # dict() keeps the first position and the last value of a repeated path,
    # which matches building the result through successive update() calls.
    return dict(_leaves(d, parent_key))
def _create_base_task(
env: str, runner: str, training: str, queue: str
) -> str:
@@ -139,6 +154,8 @@ def _create_base_task(
Uses Task.create() to register a task pointing at scripts/train.py
with the correct Hydra overrides. The HPO optimizer will clone this.
The full resolved OmegaConf config is attached as Hydra/* parameters
so cloned trial tasks inherit the complete configuration.
"""
script_path = str(Path(__file__).resolve().parent / "train.py")
project_root = str(Path(__file__).resolve().parent.parent)
@@ -157,14 +174,44 @@ def _create_base_task(
add_task_init_call=False,
)
# Explicitly set Hydra config-group choices so cloned tasks
# pick up the correct env / runner / training groups.
# Task.create() does not populate the Hydra parameter section
# because Hydra never actually runs during creation.
# ── Attach full resolved OmegaConf config ─────────────────────
# ClearML's Hydra binding normally does this when the script runs,
# but Task.create() never executes Hydra. We replicate the binding
# manually: config group choices + all resolved values.
base_task.set_parameter("Hydra/env", env)
base_task.set_parameter("Hydra/runner", runner)
base_task.set_parameter("Hydra/training", training)
# Load and resolve the full config for each group
configs_dir = Path(__file__).resolve().parent.parent / "configs"
for section, name in [("training", training), ("env", env), ("runner", runner)]:
cfg_path = configs_dir / section / f"{name}.yaml"
if not cfg_path.exists():
continue
cfg = OmegaConf.load(cfg_path)
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
if "defaults" in cfg:
defaults = OmegaConf.to_container(cfg.defaults)
base_cfg = OmegaConf.create({})
for d in defaults:
if isinstance(d, str):
base_path = configs_dir / section / f"{d}.yaml"
if base_path.exists():
loaded = OmegaConf.load(base_path)
base_cfg = OmegaConf.merge(base_cfg, loaded)
cfg_no_defaults = {
k: v for k, v in OmegaConf.to_container(cfg).items()
if k != "defaults"
}
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
resolved = OmegaConf.to_container(cfg, resolve=True)
# Remove hpo metadata — not a real config value
resolved.pop("hpo", None)
flat = _flatten_dict(resolved)
for key, value in flat.items():
base_task.set_parameter(f"Hydra/{section}.{key}", value)
# Set docker config
base_task.set_base_docker(
"registry.kube.optimize/worker-image:latest",
@@ -267,15 +314,6 @@ def main() -> None:
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
return
# ── Create or reuse base task ─────────────────────────────────
if args.base_task_id:
base_task_id = args.base_task_id
logger.info("using_existing_base_task", task_id=base_task_id)
else:
base_task_id = _create_base_task(
args.env, args.runner, args.training, args.queue
)
# ── Initialize ClearML HPO task ───────────────────────────────
Task.ignore_requirements("torch")
task = Task.init(
@@ -295,6 +333,24 @@ def main() -> None:
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
task.set_packages(str(req_file))
# ── Create or reuse base task ─────────────────────────────────
# Store the base_task_id on the HPO task so that when the services
# worker re-runs this script it reuses the same base task instead
# of creating a duplicate.
if args.base_task_id:
base_task_id = args.base_task_id
logger.info("using_existing_base_task", task_id=base_task_id)
else:
existing = task.get_parameter("General/base_task_id")
if existing:
base_task_id = existing
logger.info("reusing_base_task_from_param", task_id=base_task_id)
else:
base_task_id = _create_base_task(
args.env, args.runner, args.training, args.queue
)
task.set_parameter("General/base_task_id", base_task_id)
# ── Build objective metric ────────────────────────────────────
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
objective_title = args.objective_metric