⚗️ experimenting training runs
This commit is contained in:
@@ -8,15 +8,17 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
|
|
||||||
hidden_sizes: [256, 256]
|
hidden_sizes: [256, 256]
|
||||||
total_timesteps: 500000
|
total_timesteps: 1000000
|
||||||
learning_epochs: 5
|
learning_epochs: 10
|
||||||
learning_rate: 0.001
|
learning_rate: 0.0003
|
||||||
entropy_loss_scale: 0.0001
|
entropy_loss_scale: 0.01
|
||||||
log_interval: 1024
|
rollout_steps: 2048
|
||||||
|
mini_batches: 8
|
||||||
|
log_interval: 2048
|
||||||
checkpoint_interval: 10000
|
checkpoint_interval: 10000
|
||||||
initial_log_std: -0.5
|
initial_log_std: -0.5
|
||||||
min_log_std: -4.0
|
min_log_std: -4.0
|
||||||
max_log_std: 0.0
|
max_log_std: 2.0
|
||||||
|
|
||||||
record_video_every: 50000
|
record_video_every: 50000
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,17 @@ search ranges from the Hydra config's `training.hpo` and `env.hpo` blocks,
|
|||||||
and launches SMAC3 Successive Halving optimization.
|
and launches SMAC3 Successive Halving optimization.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python scripts/hpo.py \
|
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single
|
||||||
--env rotary_cartpole \
|
|
||||||
--runner mujoco_single \
|
# With HPO-specific options:
|
||||||
--training ppo_single \
|
python scripts/hpo.py env=rotary_cartpole runner=mujoco_single training=ppo_single \\
|
||||||
--queue gpu-queue
|
--queue gpu-queue --total-trials 100
|
||||||
|
|
||||||
# Or use an existing base task:
|
# Or use an existing base task:
|
||||||
python scripts/hpo.py --base-task-id <TASK_ID>
|
python scripts/hpo.py --base-task-id <TASK_ID>
|
||||||
|
|
||||||
|
# Dry run (print search space only):
|
||||||
|
python scripts/hpo.py env=rotary_cartpole --dry-run
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -233,9 +236,33 @@ def _create_base_task(
|
|||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_overrides(argv: list[str]) -> dict[str, str]:
|
||||||
|
"""Parse Hydra-style key=value overrides from argv.
|
||||||
|
|
||||||
|
Returns a dict of parsed key-value pairs. Unknown args (--flags)
|
||||||
|
are left in argv for argparse to handle.
|
||||||
|
"""
|
||||||
|
overrides = {}
|
||||||
|
remaining = []
|
||||||
|
for arg in argv:
|
||||||
|
if "=" in arg and not arg.startswith("-"):
|
||||||
|
key, value = arg.split("=", 1)
|
||||||
|
overrides[key] = value
|
||||||
|
else:
|
||||||
|
remaining.append(arg)
|
||||||
|
argv.clear()
|
||||||
|
argv.extend(remaining)
|
||||||
|
return overrides
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
# First pass: extract Hydra-style key=value overrides from sys.argv
|
||||||
|
raw_args = sys.argv[1:]
|
||||||
|
overrides = _parse_overrides(raw_args)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Hyperparameter optimization for RL-Framework"
|
description="Hyperparameter optimization for RL-Framework",
|
||||||
|
usage="%(prog)s env=<ENV> runner=<RUNNER> training=<TRAINING> [options]",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--base-task-id",
|
"--base-task-id",
|
||||||
@@ -243,9 +270,6 @@ def main() -> None:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Existing ClearML task ID to use as base (skip auto-creation)",
|
help="Existing ClearML task ID to use as base (skip auto-creation)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--env", type=str, default="rotary_cartpole")
|
|
||||||
parser.add_argument("--runner", type=str, default="mujoco_single")
|
|
||||||
parser.add_argument("--training", type=str, default="ppo_single")
|
|
||||||
parser.add_argument("--queue", type=str, default="gpu-queue")
|
parser.add_argument("--queue", type=str, default="gpu-queue")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-concurrent", type=int, default=2,
|
"--max-concurrent", type=int, default=2,
|
||||||
@@ -292,12 +316,17 @@ def main() -> None:
|
|||||||
"--dry-run", action="store_true",
|
"--dry-run", action="store_true",
|
||||||
help="Print search space and exit without running",
|
help="Print search space and exit without running",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args(raw_args)
|
||||||
|
|
||||||
|
# Resolve env/runner/training from Hydra-style overrides (same as train.py)
|
||||||
|
env = overrides.get("env", "rotary_cartpole")
|
||||||
|
runner = overrides.get("runner", "mujoco_single")
|
||||||
|
training = overrides.get("training", "ppo_single")
|
||||||
|
|
||||||
objective_sign = "min" if args.minimize else "max"
|
objective_sign = "min" if args.minimize else "max"
|
||||||
|
|
||||||
# ── Load config and build search space ────────────────────────
|
# ── Load config and build search space ────────────────────────
|
||||||
config = _load_hydra_config(args.env, args.runner, args.training)
|
config = _load_hydra_config(env, runner, training)
|
||||||
hyper_parameters = _build_hyper_parameters(config)
|
hyper_parameters = _build_hyper_parameters(config)
|
||||||
|
|
||||||
if not hyper_parameters:
|
if not hyper_parameters:
|
||||||
@@ -318,7 +347,7 @@ def main() -> None:
|
|||||||
Task.ignore_requirements("torch")
|
Task.ignore_requirements("torch")
|
||||||
task = Task.init(
|
task = Task.init(
|
||||||
project_name="RL-Framework",
|
project_name="RL-Framework",
|
||||||
task_name=f"HPO {args.env}-{args.runner}-{args.training}",
|
task_name=f"HPO {env}-{runner}-{training}",
|
||||||
task_type=Task.TaskTypes.optimizer,
|
task_type=Task.TaskTypes.optimizer,
|
||||||
reuse_last_task_id=False,
|
reuse_last_task_id=False,
|
||||||
)
|
)
|
||||||
@@ -347,7 +376,7 @@ def main() -> None:
|
|||||||
logger.info("reusing_base_task_from_param", task_id=base_task_id)
|
logger.info("reusing_base_task_from_param", task_id=base_task_id)
|
||||||
else:
|
else:
|
||||||
base_task_id = _create_base_task(
|
base_task_id = _create_base_task(
|
||||||
args.env, args.runner, args.training, args.queue
|
env, runner, training, args.queue
|
||||||
)
|
)
|
||||||
task.set_parameter("General/base_task_id", base_task_id)
|
task.set_parameter("General/base_task_id", base_task_id)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user