101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
import pathlib
|
|
|
|
import hydra
|
|
import hydra.utils as hydra_utils
|
|
import structlog
|
|
from clearml import Task
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
|
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
|
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
from src.training.trainer import Trainer, TrainerConfig
|
|
|
|
logger = structlog.get_logger()


# ── env registry ──────────────────────────────────────────────────────
# Selected option of the Hydra ``env`` config group → (env class, config class).
# _build_env resolves the chosen option through this table; add new
# environments here to make them selectable from the command line.
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = dict(
    cartpole=(CartPoleEnv, CartPoleConfig),
    rotary_cartpole=(RotaryCartPoleEnv, RotaryCartPoleConfig),
)
|
|
|
|
|
|
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Construct the environment registered under *env_name*.

    The (env class, config class) pair is looked up in ``ENV_REGISTRY``.
    ``cfg.env`` is converted to a plain container with interpolations
    resolved. If it has an ``actuators`` list, each entry becomes an
    ``ActuatorConfig``, with any ``ctrl_range`` coerced to a tuple. The
    config class is built from the result and passed to the env class.

    Args:
        env_name: Key into ``ENV_REGISTRY`` (the selected Hydra env option).
        cfg: Full Hydra config; only ``cfg.env`` is read.

    Returns:
        The instantiated environment.

    Raises:
        ValueError: If *env_name* has no entry in ``ENV_REGISTRY``.
    """
    registered = ENV_REGISTRY.get(env_name)
    if registered is None:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
    env_cls, config_cls = registered

    params = OmegaConf.to_container(cfg.env, resolve=True)

    # Actuators arrive as plain dicts (ctrl_range as a list); turn each one
    # into a typed ActuatorConfig with a tuple ctrl_range.
    if "actuators" in params:
        typed_actuators = []
        for spec in params["actuators"]:
            if "ctrl_range" in spec:
                spec["ctrl_range"] = tuple(spec["ctrl_range"])
            typed_actuators.append(ActuatorConfig(**spec))
        params["actuators"] = typed_actuators

    env_config = config_cls(**params)
    return env_cls(env_config)
|
|
|
|
|
|
def _init_clearml(
    choices: dict[str, str],
    remote: bool = False,
    *,
    queue_name: str = "gpu-queue",
    docker_image: str = "registry.kube.optimize/worker-image:latest",
) -> Task:
    """Initialize the ClearML task for this run and optionally enqueue it remotely.

    Project:   ``RL-Framework`` (fixed for all runs).
    Task name: ``<env>-<runner>-<training>`` from the Hydra group choices.
    Tags:      the env, runner and training choices.

    Missing choices fall back to ``cartpole``, ``mujoco`` and ``ppo``.

    Args:
        choices: Hydra runtime group choices (group name -> selected option).
        remote: If True and the task is running locally, enqueue it on
            *queue_name* and terminate the local process.
        queue_name: ClearML queue used for remote execution.
        docker_image: Base docker image recorded on the task.

    Returns:
        The initialized ClearML ``Task``. When remote execution is triggered,
        the local process exits inside ``execute_remotely`` and this function
        does not return.
    """
    # ClearML requires ignore_requirements() to be called before Task.init
    # for it to affect the auto-detected requirements.
    Task.ignore_requirements("torch")

    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")

    project = "RL-Framework"
    task_name = f"{env_name}-{runner_name}-{training_name}"
    tags = [env_name, runner_name, training_name]

    task = Task.init(project_name=project, task_name=task_name, tags=tags)
    task.set_base_docker(docker_image)

    # Resolve against the directory the script was launched from; Hydra may
    # have changed the working directory to the run's output folder.
    req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
    task.set_packages(str(req_file))

    # On the remote worker the task is not running locally, so the re-run
    # script does not enqueue itself a second time.
    if remote and task.running_locally():
        logger.info("executing_task_remotely", queue=queue_name)
        task.execute_remotely(queue_name=queue_name, exit_process=True)

    return task
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: set up ClearML, build env/runner/trainer, and train.

    The ClearML task is always closed, even when environment/runner/trainer
    construction or training raises. The trainer is closed whenever training
    was started.

    Args:
        cfg: Composed Hydra config; ``cfg.env``, ``cfg.runner`` and
            ``cfg.training`` are read.
    """
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # there the task is not running locally, so it is not re-enqueued.
    training_dict = OmegaConf.to_container(cfg.training, resolve=True)
    # ``remote`` only controls launching; pop it so it is not forwarded
    # to TrainerConfig below.
    remote = training_dict.pop("remote", False)
    task = _init_clearml(choices, remote=remote)

    # Outer try/finally: close the ClearML task even if setup fails or
    # trainer.close() itself raises.
    try:
        env_name = choices.get("env", "cartpole")
        env = _build_env(env_name, cfg)
        runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
        trainer_config = TrainerConfig(**training_dict)

        runner = MuJoCoRunner(env=env, config=runner_config)
        trainer = Trainer(runner=runner, config=trainer_config)

        try:
            trainer.train()
        finally:
            trainer.close()
    finally:
        task.close()
|
|
|
|
|
|
# Script entry: main() is wrapped by @hydra.main, which composes the config
# (including command-line overrides) and passes it in as ``cfg``.
if __name__ == "__main__":
    main()