✨ update urdf and dependencies
This commit is contained in:
16
train.py
16
train.py
@@ -1,4 +1,8 @@
|
||||
import pathlib
|
||||
|
||||
import hydra
|
||||
import hydra.utils as hydra_utils
|
||||
import structlog
|
||||
from clearml import Task
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
@@ -9,6 +13,8 @@ from src.envs.rotary_cartpole import RotaryCartPoleEnv, RotaryCartPoleConfig
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
from src.training.trainer import Trainer, TrainerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── env registry ──────────────────────────────────────────────────────
|
||||
# Maps Hydra config-group name → (EnvClass, ConfigClass)
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
@@ -52,9 +58,15 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
tags = [env_name, runner_name, training_name]
|
||||
|
||||
task = Task.init(project_name=project, task_name=task_name, tags=tags)
|
||||
task.set_base_docker("registry.kube.optimize/worker-image:latest")
|
||||
|
||||
if remote:
|
||||
task.execute_remotely(queue_name="default")
|
||||
req_file = pathlib.Path(hydra_utils.get_original_cwd()) / "requirements.txt"
|
||||
task.set_packages(str(req_file))
|
||||
|
||||
# Execute remotely if requested and running locally
|
||||
if remote and task.running_locally():
|
||||
logger.info("executing_task_remotely", queue="gpu-queue")
|
||||
task.execute_remotely(queue_name="gpu-queue", exit_process=True)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
Reference in New Issue
Block a user