138 lines
4.8 KiB
Python
138 lines
4.8 KiB
Python
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
|
|
|
Usage:
|
|
mjpython viz.py env=rotary_cartpole
|
|
mjpython viz.py env=cartpole +com=true
|
|
|
|
Controls:
|
|
Left/Right arrows — apply torque to first actuator
|
|
R — reset environment
|
|
Esc / close window — quit
|
|
"""
|
|
import math
|
|
import time
|
|
|
|
import hydra
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import structlog
|
|
import torch
|
|
from hydra.core.hydra_config import HydraConfig
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
|
|
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
|
|
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
|
|
# Structured logger shared by the whole script (viewer status + per-step debug).
logger = structlog.get_logger()


# ── registry (same as train.py) ──────────────────────────────────────
# Hydra ``env`` config-group choice -> (environment class, its config class).
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = dict(
    cartpole=(CartPoleEnv, CartPoleConfig),
    rotary_cartpole=(RotaryCartPoleEnv, RotaryCartPoleConfig),
)
|
|
|
|
|
|
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
    """Instantiate the environment registered under ``env_name``.

    ``cfg.env`` is resolved into a plain container and passed as keyword
    arguments to the registered config class.  If it has an ``actuators``
    entry, each item becomes an ``ActuatorConfig``, with ``ctrl_range``
    coerced to a tuple (``OmegaConf.to_container`` yields lists).

    Raises:
        ValueError: if ``env_name`` is not a key of ``ENV_REGISTRY``.
    """
    if env_name not in ENV_REGISTRY:
        raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
    env_cls, config_cls = ENV_REGISTRY[env_name]

    params = OmegaConf.to_container(cfg.env, resolve=True)

    if "actuators" in params:
        # First normalise every ctrl_range, then build the config objects.
        normalised = [
            {**spec, "ctrl_range": tuple(spec["ctrl_range"])} if "ctrl_range" in spec else spec
            for spec in params["actuators"]
        ]
        params["actuators"] = [ActuatorConfig(**spec) for spec in normalised]

    return env_cls(config_cls(**params))
|
|
|
|
|
|
# ── keyboard state ───────────────────────────────────────────────────
# One-element lists serve as mutable cells so the callback can update them
# without ``global`` statements.
_action_val = [0.0]  # direction of the most recent arrow-key event (-1.0 / +1.0)
_action_time = [0.0]  # time.time() of the most recent arrow-key event
_reset_flag = [False]  # raised by the callback, cleared by the main loop
_ACTION_HOLD_S = 0.25  # seconds an arrow-key action stays active after its last event

# GLFW key codes as delivered to the MuJoCo viewer key callback.
_KEY_DIRECTION = {
    263: -1.0,  # GLFW_KEY_LEFT
    262: 1.0,  # GLFW_KEY_RIGHT
}
_KEY_RESET = 82  # GLFW_KEY_R


def _key_callback(keycode: int) -> None:
    """Record an arrow-key torque direction or a reset request.

    MuJoCo calls this on key press and repeat, never on release, which is
    why the main loop lets an action lapse ``_ACTION_HOLD_S`` seconds after
    the last event.  Keys other than Left, Right and R are ignored.
    """
    direction = _KEY_DIRECTION.get(keycode)
    if direction is not None:
        _action_val[0] = direction
        _action_time[0] = time.time()
    elif keycode == _KEY_RESET:
        _reset_flag[0] = True
|
|
|
|
|
|
def _joint_positions(model: mujoco.MjModel, data: mujoco.MjData) -> dict[str, float]:
    """Return ``{joint_name: position}`` for the scalar joints of ``model``.

    Hinge joints are reported in degrees (MuJoCo stores hinge angles in
    radians) and slide joints in the model's length units.  Free and ball
    joints have no single scalar coordinate and are omitted.  Unnamed
    joints are keyed as ``joint<id>``.
    """
    positions: dict[str, float] = {}
    for jid in range(model.njnt):
        jtype = model.jnt_type[jid]
        # qpos is laid out by joint address; this only coincides with the
        # joint id when every preceding joint has exactly one coordinate.
        value = float(data.qpos[model.jnt_qposadr[jid]])
        name = model.jnt(jid).name or f"joint{jid}"
        if jtype == mujoco.mjtJoint.mjJNT_HINGE:
            positions[name] = round(math.degrees(value), 1)
        elif jtype == mujoco.mjtJoint.mjJNT_SLIDE:
            positions[name] = round(value, 3)
    return positions


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Open a passive MuJoCo viewer and drive a single env from the keyboard.

    The env is selected by the Hydra ``env`` config-group choice (default
    ``cartpole``).  The runner is built from ``cfg.runner`` with ``num_envs``
    forced to 1.  Passing ``+com=true`` turns on the viewer's centre-of-mass
    and inertia visualisation.  The loop runs until the viewer window is
    closed; the runner is always closed on exit.
    """
    choices = HydraConfig.get().runtime.choices
    env_name = choices.get("env", "cartpole")

    # Build env + runner (single env for viz)
    env = _build_env(env_name, cfg)
    runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
    runner_dict["num_envs"] = 1
    runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))

    try:
        model = runner._model
        data = runner._data[0]

        # Wall-clock length of one control step (physics dt x substeps).
        dt_ctrl = runner.config.dt * runner.config.substeps

        with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
            # Show CoM / inertia if requested via Hydra override: viz.py +com=true
            if cfg.get("com", False):
                # The passive viewer renders on its own thread; visualisation
                # options must be changed while holding its lock.
                with viewer.lock():
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
                    viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True

            runner.reset()
            step = 0

            logger.info("viewer_started", env=env_name,
                        controls="Left/Right arrows = torque, R = reset")

            while viewer.is_running():
                step_start = time.perf_counter()

                # Read action from callback.  MuJoCo reports presses/repeats
                # but not releases, so an action expires after _ACTION_HOLD_S.
                if time.time() - _action_time[0] < _ACTION_HOLD_S:
                    action_val = _action_val[0]
                else:
                    action_val = 0.0

                # Reset on R press
                if _reset_flag[0]:
                    _reset_flag[0] = False
                    runner.reset()
                    step = 0
                    logger.info("reset")

                # NOTE(review): a (1, 1) action assumes a single action
                # dimension; confirm for envs with several actuators.
                action = torch.tensor([[action_val]])
                _, reward, _, _, _ = runner.step(action)

                # Recompute derived quantities for the current state, then
                # hand the frame to the viewer.
                mujoco.mj_forward(model, data)
                viewer.sync()

                # Print state
                if step % 25 == 0:
                    logger.debug("step", n=step, reward=round(reward.item(), 3),
                                 action=round(action_val, 1),
                                 **_joint_positions(model, data))

                # Real-time pacing: sleep only for what remains of the control
                # period after stepping, syncing and logging.
                remaining = dt_ctrl - (time.perf_counter() - step_start)
                if remaining > 0:
                    time.sleep(remaining)
                step += 1
    finally:
        # Release runner resources even if the viewer loop raises.
        runner.close()


if __name__ == "__main__":
    main()
|