♻️ full agent refactor
This commit is contained in:
@@ -75,14 +75,13 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
|
||||
|
||||
|
||||
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
|
||||
"""Return the history/adaptation encoder output dim, if present.
|
||||
"""Return the history encoder output dim, if present.
|
||||
|
||||
Lets eval reconstruct an embedding policy without knowing the training
|
||||
embedding_dim/latent_dim — read it straight from the saved weights.
|
||||
embedding_dim — read it straight from the saved weights.
|
||||
"""
|
||||
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
|
||||
if key in state_dict:
|
||||
return state_dict[key].shape[0]
|
||||
if "history_encoder.fc.weight" in state_dict:
|
||||
return state_dict["history_encoder.fc.weight"].shape[0]
|
||||
return None
|
||||
|
||||
|
||||
@@ -92,14 +91,13 @@ def load_policy(
|
||||
action_space: spaces.Space,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
history_length: int = 0,
|
||||
rma_mode: str = "none",
|
||||
raw_obs_dim: int = 0,
|
||||
) -> tuple[SharedMLP, RunningStandardScaler]:
|
||||
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
|
||||
|
||||
For DR + history-embedding policies (history_length > 0) or RMA deploy
|
||||
policies (rma_mode="deploy"), the history/adaptation encoder must be
|
||||
reconstructed too — its output dim is read back from the saved weights.
|
||||
For DR + history-embedding policies (history_length > 0), the history
|
||||
encoder is reconstructed too — its output dim is read back from the
|
||||
saved weights.
|
||||
|
||||
Returns:
|
||||
(model, state_preprocessor) ready for inference.
|
||||
@@ -117,11 +115,9 @@ def load_policy(
|
||||
action_space=action_space,
|
||||
device=device,
|
||||
hidden_sizes=hidden_sizes,
|
||||
history_length=history_length,
|
||||
rma_mode=rma_mode,
|
||||
history_length=history_length if enc_out else 0,
|
||||
raw_obs_dim=raw_obs_dim,
|
||||
embedding_dim=enc_out or 32, # legacy "none" + history
|
||||
latent_dim=enc_out or 8, # RMA deploy adaptation module
|
||||
embedding_dim=enc_out or 32,
|
||||
)
|
||||
model.load_state_dict(ckpt["policy"])
|
||||
model.eval()
|
||||
@@ -189,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco_single")
|
||||
|
||||
checkpoint_path = cfg.get("checkpoint", None)
|
||||
@@ -222,7 +218,6 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, runner.observation_space, runner.action_space, device,
|
||||
history_length=runner.config.history_length,
|
||||
rma_mode=runner.config.rma_mode,
|
||||
raw_obs_dim=runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
@@ -311,7 +306,6 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
|
||||
history_length=serial_runner.config.history_length,
|
||||
rma_mode=serial_runner.config.rma_mode,
|
||||
raw_obs_dim=serial_runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
@@ -339,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
serial_runner._send("M0")
|
||||
serial_runner._drive_to_center()
|
||||
serial_runner._wait_for_pendulum_still()
|
||||
obs, _ = serial_runner.reset()
|
||||
obs, _ = serial_runner.reset() # drives to center + settles
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
@@ -376,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
"step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action[0, 0].item(), 2),
|
||||
ep_reward=round(episode_reward, 1),
|
||||
motor_enc=state["encoder_count"],
|
||||
pend_deg=round(state["pendulum_angle"], 1),
|
||||
motor_deg=round(math.degrees(state["motor_rad"]), 1),
|
||||
pend_deg=round(math.degrees(state["pend_rad"]), 1),
|
||||
)
|
||||
|
||||
# Check for safety / disconnection.
|
||||
|
||||
Reference in New Issue
Block a user