✨ initial commit
This commit is contained in:
47
train.py
Normal file
47
train.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import hydra
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.envs.cartpole import CartPoleEnv, CartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
from src.core.env import ActuatorConfig
|
||||
|
||||
|
||||
def _build_env_config(cfg: DictConfig) -> CartPoleConfig:
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
return CartPoleConfig(**env_dict)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
env_config = _build_env_config(cfg)
|
||||
runner_config = MuJoCoRunnerConfig(**OmegaConf.to_container(cfg.runner, resolve=True))
|
||||
|
||||
training_dict = OmegaConf.to_container(cfg.training, resolve=True)
|
||||
# Build ClearML task name dynamically from Hydra config group choices
|
||||
if not training_dict.get("clearml_task"):
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "env")
|
||||
runner_name = choices.get("runner", "runner")
|
||||
training_name = choices.get("training", "algo")
|
||||
training_dict["clearml_task"] = f"{env_name}-{runner_name}-{training_name}"
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
|
||||
env = CartPoleEnv(env_config)
|
||||
runner = MuJoCoRunner(env=env, config=runner_config)
|
||||
trainer = Trainer(runner=runner, config=trainer_config)
|
||||
|
||||
try:
|
||||
trainer.train()
|
||||
finally:
|
||||
trainer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user