✨ update hpo
This commit is contained in:
@@ -132,6 +132,21 @@ def _build_hyper_parameters(config: dict) -> list:
|
||||
return params
|
||||
|
||||
|
||||
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
|
||||
"""Flatten a nested dict into dot-separated keys.
|
||||
|
||||
Example: {"a": {"b": 1}} → {"a.b": 1}
|
||||
"""
|
||||
items = {}
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.update(_flatten_dict(v, new_key, sep=sep))
|
||||
else:
|
||||
items[new_key] = v
|
||||
return items
|
||||
|
||||
|
||||
def _create_base_task(
|
||||
env: str, runner: str, training: str, queue: str
|
||||
) -> str:
|
||||
@@ -139,6 +154,8 @@ def _create_base_task(
|
||||
|
||||
Uses Task.create() to register a task pointing at scripts/train.py
|
||||
with the correct Hydra overrides. The HPO optimizer will clone this.
|
||||
The full resolved OmegaConf config is attached as Hydra/* parameters
|
||||
so cloned trial tasks inherit the complete configuration.
|
||||
"""
|
||||
script_path = str(Path(__file__).resolve().parent / "train.py")
|
||||
project_root = str(Path(__file__).resolve().parent.parent)
|
||||
@@ -157,14 +174,44 @@ def _create_base_task(
|
||||
add_task_init_call=False,
|
||||
)
|
||||
|
||||
# Explicitly set Hydra config-group choices so cloned tasks
|
||||
# pick up the correct env / runner / training groups.
|
||||
# Task.create() does not populate the Hydra parameter section
|
||||
# because Hydra never actually runs during creation.
|
||||
# ── Attach full resolved OmegaConf config ─────────────────────
|
||||
# ClearML's Hydra binding normally does this when the script runs,
|
||||
# but Task.create() never executes Hydra. We replicate the binding
|
||||
# manually: config group choices + all resolved values.
|
||||
base_task.set_parameter("Hydra/env", env)
|
||||
base_task.set_parameter("Hydra/runner", runner)
|
||||
base_task.set_parameter("Hydra/training", training)
|
||||
|
||||
# Load and resolve the full config for each group
|
||||
configs_dir = Path(__file__).resolve().parent.parent / "configs"
|
||||
for section, name in [("training", training), ("env", env), ("runner", runner)]:
|
||||
cfg_path = configs_dir / section / f"{name}.yaml"
|
||||
if not cfg_path.exists():
|
||||
continue
|
||||
cfg = OmegaConf.load(cfg_path)
|
||||
# Handle Hydra defaults: inheritance (e.g. ppo_single → ppo)
|
||||
if "defaults" in cfg:
|
||||
defaults = OmegaConf.to_container(cfg.defaults)
|
||||
base_cfg = OmegaConf.create({})
|
||||
for d in defaults:
|
||||
if isinstance(d, str):
|
||||
base_path = configs_dir / section / f"{d}.yaml"
|
||||
if base_path.exists():
|
||||
loaded = OmegaConf.load(base_path)
|
||||
base_cfg = OmegaConf.merge(base_cfg, loaded)
|
||||
cfg_no_defaults = {
|
||||
k: v for k, v in OmegaConf.to_container(cfg).items()
|
||||
if k != "defaults"
|
||||
}
|
||||
cfg = OmegaConf.merge(base_cfg, OmegaConf.create(cfg_no_defaults))
|
||||
|
||||
resolved = OmegaConf.to_container(cfg, resolve=True)
|
||||
# Remove hpo metadata — not a real config value
|
||||
resolved.pop("hpo", None)
|
||||
flat = _flatten_dict(resolved)
|
||||
for key, value in flat.items():
|
||||
base_task.set_parameter(f"Hydra/{section}.{key}", value)
|
||||
|
||||
# Set docker config
|
||||
base_task.set_base_docker(
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
@@ -267,15 +314,6 @@ def main() -> None:
|
||||
print(f"\nObjective: {args.objective_metric} ({objective_sign})")
|
||||
return
|
||||
|
||||
# ── Create or reuse base task ─────────────────────────────────
|
||||
if args.base_task_id:
|
||||
base_task_id = args.base_task_id
|
||||
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||
else:
|
||||
base_task_id = _create_base_task(
|
||||
args.env, args.runner, args.training, args.queue
|
||||
)
|
||||
|
||||
# ── Initialize ClearML HPO task ───────────────────────────────
|
||||
Task.ignore_requirements("torch")
|
||||
task = Task.init(
|
||||
@@ -295,6 +333,24 @@ def main() -> None:
|
||||
req_file = Path(__file__).resolve().parent.parent / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# ── Create or reuse base task ─────────────────────────────────
|
||||
# Store the base_task_id on the HPO task so that when the services
|
||||
# worker re-runs this script it reuses the same base task instead
|
||||
# of creating a duplicate.
|
||||
if args.base_task_id:
|
||||
base_task_id = args.base_task_id
|
||||
logger.info("using_existing_base_task", task_id=base_task_id)
|
||||
else:
|
||||
existing = task.get_parameter("General/base_task_id")
|
||||
if existing:
|
||||
base_task_id = existing
|
||||
logger.info("reusing_base_task_from_param", task_id=base_task_id)
|
||||
else:
|
||||
base_task_id = _create_base_task(
|
||||
args.env, args.runner, args.training, args.queue
|
||||
)
|
||||
task.set_parameter("General/base_task_id", base_task_id)
|
||||
|
||||
# ── Build objective metric ────────────────────────────────────
|
||||
# skrl's SequentialTrainer logs "Reward / Total reward (mean)" by default
|
||||
objective_title = args.objective_metric
|
||||
|
||||
Reference in New Issue
Block a user