import os
import pathlib
import sys

# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
if sys.platform == "linux" and "DISPLAY" not in os.environ:
    os.environ.setdefault("MUJOCO_GL", "osmesa")

import hydra
import hydra.utils as hydra_utils
import structlog
from clearml import Task
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
from src.core.runner import BaseRunner
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
from src.training.trainer import Trainer, TrainerConfig

logger = structlog.get_logger()


# ── env registry ──────────────────────────────────────────────────────
# Maps Hydra config-group name → (EnvClass, ConfigClass)
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
    "cartpole": (CartPoleEnv, CartPoleConfig),
    "rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}


def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Build the environment registered under *env_name*.

    The ``cfg.env`` node is resolved into a plain container, any
    ``actuators`` entries are turned into ``ActuatorConfig`` objects, and
    the result is passed as keyword arguments to the registered config
    class, which in turn is handed to the registered env class.

    Args:
        env_name: Key into ``ENV_REGISTRY`` (the Hydra config-group name).
        cfg: Full Hydra config; only ``cfg.env`` is read here.

    Returns:
        The constructed environment instance.

    Raises:
        ValueError: If *env_name* is not a key of ``ENV_REGISTRY``.
    """
    registered = ENV_REGISTRY.get(env_name)
    if registered is None:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
    env_cls, config_cls = registered

    env_kwargs = OmegaConf.to_container(cfg.env, resolve=True)

    # Convert actuator dicts → ActuatorConfig objects
    if "actuators" in env_kwargs:
        # First pass: the resolved container yields lists, but ctrl_range is
        # handed to ActuatorConfig as a tuple.
        actuator_specs = [
            {**spec, "ctrl_range": tuple(spec["ctrl_range"])}
            if "ctrl_range" in spec
            else spec
            for spec in env_kwargs["actuators"]
        ]
        # Second pass: build the config objects once every spec is normalised.
        env_kwargs["actuators"] = [ActuatorConfig(**spec) for spec in actuator_specs]

    return env_cls(config_cls(**env_kwargs))


# ── runner registry ───────────────────────────────────────────────────
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
    "mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
    "mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
}


def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
    """Build the runner registered under *runner_name*.

    The registry stores ``(module path, runner class name, config class
    name)`` as strings; the module is imported only here, at call time.

    Args:
        runner_name: Key into ``RUNNER_REGISTRY`` (the Hydra config-group name).
        env: Environment instance passed to the runner as ``env=``.
        cfg: Full Hydra config; only ``cfg.runner`` is read here.

    Returns:
        The constructed runner instance.

    Raises:
        ValueError: If *runner_name* is not a key of ``RUNNER_REGISTRY``.
    """
    spec = RUNNER_REGISTRY.get(runner_name)
    if spec is None:
        raise ValueError(
            f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
        )
    module_path, cls_name, cfg_cls_name = spec

    # Local import keeps the runner module (and whatever it pulls in) from
    # being loaded until a runner is actually requested.
    import importlib

    runner_module = importlib.import_module(module_path)
    runner_cls = getattr(runner_module, cls_name)
    config_cls = getattr(runner_module, cfg_cls_name)

    runner_kwargs = OmegaConf.to_container(cfg.runner, resolve=True)
    return runner_cls(env=env, config=config_cls(**runner_kwargs))


def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
    """Create and configure the ClearML task for this run.

    The task is created in project ``RL-Framework`` and is named
    ``<env>-<runner>-<training>``; the same three Hydra choices are also
    attached as tags. Missing choices fall back to ``cartpole``,
    ``mujoco`` and ``ppo`` respectively.

    Args:
        choices: Hydra runtime choices (config-group name → selected option).
        remote: When true and the task is running locally, the task is
            enqueued on ``gpu-queue`` via ``execute_remotely`` with
            ``exit_process=True``.

    Returns:
        The initialised ClearML ``Task``.
    """
    Task.ignore_requirements("torch")

    env_name = choices.get("env", "cartpole")
    runner_name = choices.get("runner", "mujoco")
    training_name = choices.get("training", "ppo")

    tags = [env_name, runner_name, training_name]
    task = Task.init(
        project_name="RL-Framework",
        task_name=f"{env_name}-{runner_name}-{training_name}",
        tags=tags,
    )

    # Shell snippet passed to ClearML as the docker setup script: installs the
    # GL libraries plus the JAX/MJX packages.
    setup_script = (
        "apt-get update && apt-get install -y --no-install-recommends "
        "libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
        "&& pip install 'jax[cuda12]' mujoco-mjx"
    )
    task.set_base_docker(
        "registry.kube.optimize/worker-image:latest",
        docker_setup_bash_script=setup_script,
    )

    # requirements.txt is looked up relative to Hydra's original working
    # directory rather than the current one.
    original_cwd = pathlib.Path(hydra_utils.get_original_cwd())
    task.set_packages(str(original_cwd / "requirements.txt"))

    if not remote:
        return task

    # Execute remotely if requested and running locally
    if task.running_locally():
        queue = "gpu-queue"
        logger.info("executing_task_remotely", queue=queue)
        task.execute_remotely(queue_name=queue, exit_process=True)

    return task


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: set up ClearML, build env/runner/trainer, then train.

    Args:
        cfg: Composed Hydra config; the ``env``, ``runner`` and ``training``
            nodes are read.
    """
    choices = HydraConfig.get().runtime.choices

    # ClearML init — must happen before heavy work so remote execution
    # can take over early. The remote worker re-runs the full script;
    # execute_remotely() is a no-op on the worker side.
    training_kwargs = OmegaConf.to_container(cfg.training, resolve=True)
    # "remote" is removed here, so it is not passed on to TrainerConfig below.
    run_remote = training_kwargs.pop("remote", False)
    task = _init_clearml(choices, remote=run_remote)

    env = _build_env(choices.get("env", "cartpole"), cfg)
    runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
    trainer = Trainer(runner=runner, config=TrainerConfig(**training_kwargs))

    try:
        trainer.train()
    finally:
        # Cleanup runs whether or not training raised.
        trainer.close()
        task.close()


if __name__ == "__main__":
    main()