✨ add mjx runner
This commit is contained in:
48
train.py
48
train.py
@@ -1,4 +1,10 @@
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import)
|
||||
if sys.platform == "linux" and "DISPLAY" not in os.environ:
|
||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
@@ -8,9 +14,9 @@ from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||
from src.core.runner import BaseRunner
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
@@ -41,6 +47,33 @@ def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
return env_cls(config_cls(**env_dict))
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
|
||||
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
|
||||
|
||||
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
|
||||
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
|
||||
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
|
||||
}
|
||||
|
||||
|
||||
def _build_runner(runner_name: str, env: BaseEnv, cfg: DictConfig) -> BaseRunner:
|
||||
"""Instantiate the right runner from the Hydra config-group name."""
|
||||
if runner_name not in RUNNER_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Unknown runner '{runner_name}'. Registered: {list(RUNNER_REGISTRY)}"
|
||||
)
|
||||
module_path, cls_name, cfg_cls_name = RUNNER_REGISTRY[runner_name]
|
||||
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
runner_cls = getattr(mod, cls_name)
|
||||
config_cls = getattr(mod, cfg_cls_name)
|
||||
|
||||
runner_config = config_cls(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||
return runner_cls(env=env, config=runner_config)
|
||||
|
||||
|
||||
def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
"""Initialize ClearML task with project structure and tags.
|
||||
|
||||
@@ -58,7 +91,14 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
tags = [env_name, runner_name, training_name]
|
||||
|
||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||
task.set_base_docker("registry.kube.optimize/worker-image:latest")
|
||||
task.set_base_docker(
|
||||
"registry.kube.optimize/worker-image:latest",
|
||||
docker_setup_bash_script=(
|
||||
"apt-get update && apt-get install -y --no-install-recommends "
|
||||
"libosmesa6 libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
||||
"&& pip install 'jax[cuda12]' mujoco-mjx"
|
||||
),
|
||||
)
|
||||
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
@@ -84,10 +124,8 @@ def main(cfg: DictConfig) -> None:
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env = _build_env(env_name, cfg)
|
||||
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
|
||||
runner = MuJoCoRunner(env=env, config=runner_config)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user