♻️ crazy refactor

This commit is contained in:
2026-03-11 22:52:01 +01:00
parent 35223b3560
commit 4115447022
34 changed files with 4255 additions and 102 deletions

View File

@@ -26,8 +26,10 @@ logger = structlog.get_logger()
# Imports are deferred so JAX is only loaded when runner=mjx is chosen.
RUNNER_REGISTRY: dict[str, tuple[str, str, str]] = {
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
"mujoco": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mujoco_single": ("src.runners.mujoco", "MuJoCoRunner", "MuJoCoRunnerConfig"),
"mjx": ("src.runners.mjx", "MJXRunner", "MJXRunnerConfig"),
"serial": ("src.runners.serial", "SerialRunner", "SerialRunnerConfig"),
}
@@ -94,6 +96,7 @@ def main(cfg: DictConfig) -> None:
# execute_remotely() is a no-op on the worker side.
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
remote = training_dict.pop("remote", False)
training_dict.pop("hpo", None) # HPO range metadata — not a TrainerConfig field
task = _init_clearml(choices, remote=remote)
env_name = choices.get("env", "cartpole")