🐛 bug fixes
This commit is contained in:
@@ -157,6 +157,14 @@ def _create_base_task(
|
|||||||
add_task_init_call=False,
|
add_task_init_call=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Explicitly set Hydra config-group choices so cloned tasks
|
||||||
|
# pick up the correct env / runner / training groups.
|
||||||
|
# Task.create() does not populate the Hydra parameter section
|
||||||
|
# because Hydra never actually runs during creation.
|
||||||
|
base_task.set_parameter("Hydra/env", env)
|
||||||
|
base_task.set_parameter("Hydra/runner", runner)
|
||||||
|
base_task.set_parameter("Hydra/training", training)
|
||||||
|
|
||||||
# Set docker config
|
# Set docker config
|
||||||
base_task.set_base_docker(
|
base_task.set_base_docker(
|
||||||
"registry.kube.optimize/worker-image:latest",
|
"registry.kube.optimize/worker-image:latest",
|
||||||
@@ -198,12 +206,12 @@ def main() -> None:
|
|||||||
help="Total HPO trial budget",
|
help="Total HPO trial budget",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-budget", type=int, default=3,
|
"--min-budget", type=int, default=50_000,
|
||||||
help="Minimum budget (epochs) per trial",
|
help="Minimum budget (total_timesteps) per trial",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-budget", type=int, default=81,
|
"--max-budget", type=int, default=500_000,
|
||||||
help="Maximum budget (epochs) for promoted trials",
|
help="Maximum budget (total_timesteps) for promoted trials",
|
||||||
)
|
)
|
||||||
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
|
parser.add_argument("--eta", type=int, default=3, help="Successive halving reduction factor")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -303,6 +311,7 @@ def main() -> None:
|
|||||||
pool_period_min=1,
|
pool_period_min=1,
|
||||||
time_limit_per_job=240, # 4 hours per trial max
|
time_limit_per_job=240, # 4 hours per trial max
|
||||||
eta=args.eta,
|
eta=args.eta,
|
||||||
|
budget_param_name="Hydra/training.total_timesteps",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send this HPO controller to a remote services worker
|
# Send this HPO controller to a remote services worker
|
||||||
|
|||||||
@@ -101,6 +101,12 @@ def main(cfg: DictConfig) -> None:
|
|||||||
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
||||||
task = _init_clearml(choices, remote=remote)
|
task = _init_clearml(choices, remote=remote)
|
||||||
|
|
||||||
|
# Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
|
||||||
|
# resume_from_task_id or any future additions)
|
||||||
|
import dataclasses as _dc
|
||||||
|
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
||||||
|
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "cartpole")
|
||||||
env = build_env(env_name, cfg)
|
env = build_env(env_name, cfg)
|
||||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||||
|
|||||||
@@ -175,6 +175,12 @@ class OptimizerSMAC(SearchStrategy):
|
|||||||
deterministic=True,
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Configurable budget parameter name
|
||||||
|
# Default: Hydra/training.total_timesteps (RL-Framework convention)
|
||||||
|
self.budget_param_name = smac_kwargs.pop(
|
||||||
|
"budget_param_name", "Hydra/training.total_timesteps"
|
||||||
|
)
|
||||||
|
|
||||||
# build the Successive Halving intensifier (NOT Hyperband!)
|
# build the Successive Halving intensifier (NOT Hyperband!)
|
||||||
# Hyperband runs multiple brackets with different starting budgets - wasteful
|
# Hyperband runs multiple brackets with different starting budgets - wasteful
|
||||||
# Successive Halving: ALL configs start at min_budget, only best get promoted
|
# Successive Halving: ALL configs start at min_budget, only best get promoted
|
||||||
@@ -262,11 +268,11 @@ class OptimizerSMAC(SearchStrategy):
|
|||||||
else:
|
else:
|
||||||
param_value = v
|
param_value = v
|
||||||
clone.set_parameter(original_name, param_value)
|
clone.set_parameter(original_name, param_value)
|
||||||
# Override epochs budget if multi-fidelity
|
# Override budget parameter (e.g. total_timesteps) for multi-fidelity
|
||||||
if self.max_iterations != self.min_iterations:
|
if self.max_iterations != self.min_iterations:
|
||||||
clone.set_parameter("Hydra/training.max_epochs", int(budget))
|
clone.set_parameter(self.budget_param_name, int(budget))
|
||||||
else:
|
else:
|
||||||
clone.set_parameter("Hydra/training.max_epochs", int(self.max_iterations))
|
clone.set_parameter(self.budget_param_name, int(self.max_iterations))
|
||||||
|
|
||||||
# If we have a previous task, pass its ID so the worker can download the checkpoint
|
# If we have a previous task, pass its ID so the worker can download the checkpoint
|
||||||
if prev_task_id:
|
if prev_task_id:
|
||||||
|
|||||||
6
train.py
6
train.py
@@ -99,6 +99,12 @@ def main(cfg: DictConfig) -> None:
|
|||||||
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
|
||||||
task = _init_clearml(choices, remote=remote)
|
task = _init_clearml(choices, remote=remote)
|
||||||
|
|
||||||
|
# Drop keys not recognised by TrainerConfig (e.g. ClearML-injected
|
||||||
|
# resume_from_task_id or any future additions)
|
||||||
|
import dataclasses as _dc
|
||||||
|
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
||||||
|
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
||||||
|
|
||||||
env_name = choices.get("env", "cartpole")
|
env_name = choices.get("env", "cartpole")
|
||||||
env = build_env(env_name, cfg)
|
env = build_env(env_name, cfg)
|
||||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user