✨ update urdf and dependencies
This commit is contained in:
137
viz.py
Normal file
137
viz.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Interactive visualization — control any env with keyboard in MuJoCo viewer.
|
||||
|
||||
Usage:
|
||||
mjpython viz.py env=rotary_cartpole
|
||||
mjpython viz.py env=cartpole +com=true
|
||||
|
||||
Controls:
|
||||
Left/Right arrows — apply torque to first actuator
|
||||
R — reset environment
|
||||
Esc / close window — quit
|
||||
"""
|
||||
import math
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import mujoco
|
||||
import mujoco.viewer
|
||||
import structlog
|
||||
import torch
|
||||
from hydra.core.hydra_config import HydraConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from src.core.env import ActuatorConfig, BaseEnv, BaseEnvConfig
|
||||
from src.envs.cartpole import CartPoleConfig, CartPoleEnv
|
||||
from src.envs.rotary_cartpole import RotaryCartPoleConfig, RotaryCartPoleEnv
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# ── registry (same as train.py) ──────────────────────────────────────
|
||||
ENV_REGISTRY: dict[str, tuple[type[BaseEnv], type[BaseEnvConfig]]] = {
|
||||
"cartpole": (CartPoleEnv, CartPoleConfig),
|
||||
"rotary_cartpole": (RotaryCartPoleEnv, RotaryCartPoleConfig),
|
||||
}
|
||||
|
||||
|
||||
def _build_env(env_name: str, cfg: DictConfig) -> BaseEnv:
|
||||
if env_name not in ENV_REGISTRY:
|
||||
raise ValueError(f"Unknown env '{env_name}'. Registered: {list(ENV_REGISTRY)}")
|
||||
env_cls, config_cls = ENV_REGISTRY[env_name]
|
||||
env_dict = OmegaConf.to_container(cfg.env, resolve=True)
|
||||
if "actuators" in env_dict:
|
||||
for a in env_dict["actuators"]:
|
||||
if "ctrl_range" in a:
|
||||
a["ctrl_range"] = tuple(a["ctrl_range"])
|
||||
env_dict["actuators"] = [ActuatorConfig(**a) for a in env_dict["actuators"]]
|
||||
return env_cls(config_cls(**env_dict))
|
||||
|
||||
|
||||
# ── keyboard state ───────────────────────────────────────────────────
|
||||
_action_val = [0.0] # mutable container shared with callback
|
||||
_action_time = [0.0] # timestamp of last key press
|
||||
_reset_flag = [False]
|
||||
_ACTION_HOLD_S = 0.25 # seconds the action stays active after last key event
|
||||
|
||||
|
||||
def _key_callback(keycode: int) -> None:
|
||||
"""Called by MuJoCo on key press & repeat (not release)."""
|
||||
if keycode == 263: # GLFW_KEY_LEFT
|
||||
_action_val[0] = -1.0
|
||||
_action_time[0] = time.time()
|
||||
elif keycode == 262: # GLFW_KEY_RIGHT
|
||||
_action_val[0] = 1.0
|
||||
_action_time[0] = time.time()
|
||||
elif keycode == 82: # GLFW_KEY_R
|
||||
_reset_flag[0] = True
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "cartpole")
|
||||
|
||||
# Build env + runner (single env for viz)
|
||||
env = _build_env(env_name, cfg)
|
||||
runner_dict = OmegaConf.to_container(cfg.runner, resolve=True)
|
||||
runner_dict["num_envs"] = 1
|
||||
runner = MuJoCoRunner(env=env, config=MuJoCoRunnerConfig(**runner_dict))
|
||||
|
||||
model = runner._model
|
||||
data = runner._data[0]
|
||||
|
||||
# Control period
|
||||
dt_ctrl = runner.config.dt * runner.config.substeps
|
||||
|
||||
# Launch viewer
|
||||
with mujoco.viewer.launch_passive(model, data, key_callback=_key_callback) as viewer:
|
||||
# Show CoM / inertia if requested via Hydra override: viz.py +com=true
|
||||
show_com = cfg.get("com", False)
|
||||
if show_com:
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_COM] = True
|
||||
viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_INERTIA] = True
|
||||
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
|
||||
logger.info("viewer_started", env=env_name,
|
||||
controls="Left/Right arrows = torque, R = reset")
|
||||
|
||||
while viewer.is_running():
|
||||
# Read action from callback (expires after _ACTION_HOLD_S)
|
||||
if time.time() - _action_time[0] < _ACTION_HOLD_S:
|
||||
action_val = _action_val[0]
|
||||
else:
|
||||
action_val = 0.0
|
||||
|
||||
# Reset on R press
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
obs, _ = runner.reset()
|
||||
step = 0
|
||||
logger.info("reset")
|
||||
|
||||
# Step through runner
|
||||
action = torch.tensor([[action_val]])
|
||||
obs, reward, terminated, truncated, info = runner.step(action)
|
||||
|
||||
# Sync viewer
|
||||
mujoco.mj_forward(model, data)
|
||||
viewer.sync()
|
||||
|
||||
# Print state
|
||||
if step % 25 == 0:
|
||||
joints = {model.jnt(i).name: round(math.degrees(data.qpos[i]), 1)
|
||||
for i in range(model.njnt)}
|
||||
logger.debug("step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action_val, 1), **joints)
|
||||
|
||||
# Real-time pacing
|
||||
time.sleep(dt_ctrl)
|
||||
step += 1
|
||||
|
||||
runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user