"""Interactive visualization — control any env with keyboard in MuJoCo viewer. Usage: mjpython viz.py env=rotary_cartpole mjpython viz.py env=cartpole +com=true Controls: Left/Right arrows — apply torque to first actuator R — reset environment Esc / close window — quit """ import math import time import hydra import mujoco import mujoco.viewer import structlog import torch from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf from src.core.registry import build_env from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig logger = structlog.get_logger() # ── keyboard state ─────────────────────────────────────────────────── _action_val = [0.0] # mutable container shared with callback _action_time = [0.0] # timestamp of last key press _reset_flag = [False] _ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event def _key_callback(keycode: int) -> None: """Called by MuJoCo on key press & repeat (not release).""" if keycode == 263: # GLFW_KEY_LEFT _action_val[0] = -1.0 _action_time[0] = time.time() elif keycode == 262: # GLFW_KEY_RIGHT _action_val[0] = 1.0 _action_time[0] = time.time() elif keycode == 82: # GLFW_KEY_R _reset_flag[0] = True @hydra.main(version_base=None, config_path="configs", config_name="config") def main(cfg: DictConfig) -> None: choices = HydraConfig.get().runtime.choices env_name = choices.get("env", "cartpole") # Build env + runner (single env for viz) env = build_env(env_name, cfg) runner_dict = OmegaConf.to_container(cfg.runner, resolve=True) runner_dict["num_envs"] = 1 runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict)) model = runner._model data = runner._data[0] # Control period dt_ctrl = runner.config.dt * runner.config.substeps # Launch viewer with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer: # Show CoM / inertia if requested via Hydra override: viz.py +com=true show_com = cfg.get("com", False) if show_com: viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True obs, _ = runner.reset() step = 0 logger.info("viewer_started", env=env_name, controls="Left/Right arrows = torque, R = reset") while viewer.is_running(): # Read action from callback (expires after _ACTION_HOLD_S) if time.time() - _action_time[0] < _ACTION_HOLD_S: action_val = _action_val[0] else: action_val = 0.0 # Reset on R press if _reset_flag[0]: _reset_flag[0] = False obs, _ = runner.reset() step = 0 logger.info("reset") # Step through runner action = torch.tensor([[action_val]]) obs, reward, terminated, truncated, info = runner.step(action) # Sync viewer mujoco.mj_forward(model, data) viewer.sync() # Print state if step % 25 == 0: joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1) for i in range(model.njnt)} logger.debug("step", n=step, reward=round(reward.item(), 3), action=round(action_val, 1), **joints) # Real-time pacing time.sleep(dt_ctrl) step += 1 runner.close() if __name__ == "__main__": main()