✨ better robot joint loading
This commit is contained in:
32
train.py
32
train.py
@@ -13,39 +13,13 @@ from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import BaseEnv, BaseEnvConfig, ActuatorConfig
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.registry import build_env
|
||||
from src.core.runner import BaseRunner
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── env registry ──────────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||
}
|
||||
|
||||
|
||||
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
"""Instantiate the right env + config from the Hydra config-group name."""
|
||||
if env_name not in ENV_REGISTRY:
|
||||
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||
|
||||
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
|
||||
# Convert actuator dicts → ActuatorConfig objects
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
|
||||
return env_cls(config_cls(**env_dict))
|
||||
|
||||
|
||||
# ── runner registry ───────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (RunnerClass, ConfigClass)
|
||||
@@ -123,7 +97,7 @@ def main(cfg: DictConfig) -> None:
|
||||
task = _init_clearml(choices, remote=remote)
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env = _build_env(env_name, cfg)
|
||||
env = build_env(env_name, cfg)
|
||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
Reference in New Issue
Block a user