♻️ full agent refactor

This commit is contained in:
2026-06-10 21:15:34 +02:00
parent a98e86ef66
commit 1e0836e1bc
49 changed files with 1309 additions and 829 deletions

View File

@@ -75,14 +75,13 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
"""Return the history/adaptation encoder output dim, if present.
"""Return the history encoder output dim, if present.
Lets eval reconstruct an embedding policy without knowing the training
embedding_dim/latent_dim — read it straight from the saved weights.
embedding_dim — read it straight from the saved weights.
"""
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
if key in state_dict:
return state_dict[key].shape[0]
if "history_encoder.fc.weight" in state_dict:
return state_dict["history_encoder.fc.weight"].shape[0]
return None
@@ -92,14 +91,13 @@ def load_policy(
action_space: spaces.Space,
device: torch.device = torch.device("cpu"),
history_length: int = 0,
rma_mode: str = "none",
raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]:
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
For DR + history-embedding policies (history_length > 0) or RMA deploy
policies (rma_mode="deploy"), the history/adaptation encoder must be
reconstructed too — its output dim is read back from the saved weights.
For DR + history-embedding policies (history_length > 0), the history
encoder is reconstructed too — its output dim is read back from the
saved weights.
Returns:
(model, state_preprocessor) ready for inference.
@@ -117,11 +115,9 @@ def load_policy(
action_space=action_space,
device=device,
hidden_sizes=hidden_sizes,
history_length=history_length,
rma_mode=rma_mode,
history_length=history_length if enc_out else 0,
raw_obs_dim=raw_obs_dim,
embedding_dim=enc_out or 32, # legacy "none" + history
latent_dim=enc_out or 8, # RMA deploy adaptation module
embedding_dim=enc_out or 32,
)
model.load_state_dict(ckpt["policy"])
model.eval()
@@ -189,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole")
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco_single")
checkpoint_path = cfg.get("checkpoint", None)
@@ -222,7 +218,6 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy(
checkpoint_path, runner.observation_space, runner.action_space, device,
history_length=runner.config.history_length,
rma_mode=runner.config.rma_mode,
raw_obs_dim=runner.env.observation_space.shape[0],
)
@@ -311,7 +306,6 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy(
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
history_length=serial_runner.config.history_length,
rma_mode=serial_runner.config.rma_mode,
raw_obs_dim=serial_runner.env.observation_space.shape[0],
)
@@ -339,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
if _reset_flag[0]:
_reset_flag[0] = False
serial_runner._send("M0")
serial_runner._drive_to_center()
serial_runner._wait_for_pendulum_still()
obs, _ = serial_runner.reset()
obs, _ = serial_runner.reset() # drives to center + settles
step = 0
episode += 1
episode_reward = 0.0
@@ -376,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
"step", n=step, reward=round(reward.item(), 3),
action=round(action[0, 0].item(), 2),
ep_reward=round(episode_reward, 1),
motor_enc=state["encoder_count"],
pend_deg=round(state["pendulum_angle"], 1),
motor_deg=round(math.degrees(state["motor_rad"]), 1),
pend_deg=round(math.degrees(state["pend_rad"]), 1),
)
# Check for safety / disconnection.