better robot joint loading

This commit is contained in:
2026-03-09 22:17:28 +01:00
parent 9be07d9186
commit 70cd2cdd7d
13 changed files with 215 additions and 128 deletions

25
viz.py
View File

@@ -20,32 +20,11 @@ import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
from src.core.registry import build_env
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
logger = structlog.get_logger()
# ── registry (same as train.py) ──────────────────────────────────────
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
"cartpole": (CartPoleEnv, CartPoleConfig),
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
}
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
if env_name not in ENV_REGISTRY:
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
env_cls, config_cls = ENV_REGISTRY[env_name]
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
if "actuators" in env_dict:
for a in env_dict["actuators"]:
if "ctrl_range" in a:
a["ctrl_range"] = tuple(a["ctrl_range"])
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
return env_cls(config_cls(**env_dict))
# ── keyboard state ───────────────────────────────────────────────────
_action_val = [0.0] # mutable container shared with callback
@@ -72,7 +51,7 @@ def main(cfg: DictConfig) -> None:
env_name = choices.get("env", "cartpole")
# Build env + runner (single env for viz)
env = _build_env(env_name, cfg)
env = build_env(env_name, cfg)
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
runner_dict["num_envs"] = 1
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))