r"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.

Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
and trains a HistoryEncoder (adaptation module) to predict the teacher's
latent z from observation-action history using supervised MSE.

Usage:
    python scripts/train_adaptation.py \
        --checkpoint runs/<run_name>/checkpoints/agent_XXXXX.pt \
        --env rotary_cartpole \
        --robot-path assets/rotary_cartpole \
        --num-envs 64 \
        --iterations 2000 \
        --lr 3e-4
"""

import argparse
import pathlib
import sys

# Make the repository root importable when this file is executed directly as
# a script, so the ``src.*`` imports below resolve.
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import structlog
import torch
import tqdm
from gymnasium import spaces
from omegaconf import OmegaConf

from src.core.registry import build_env
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

log = structlog.get_logger()


def _load_teacher_checkpoint(
    path: str,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    mu_dim: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Reconstruct the teacher SharedMLP and load saved weights.

    Args:
        path: Phase 1 checkpoint file.
        obs_space: Teacher observation space (``[raw_obs, μ]``).
        act_space: Environment action space.
        device: Device the model and weights are placed on.
        raw_obs_dim: Size of the raw (non-privileged) observation.
        mu_dim: Size of the privileged environment-parameter vector μ.
        hidden_sizes: Backbone MLP hidden layer widths.
        latent_dim: Size of the env-encoder latent z.

    Returns:
        The teacher model with the checkpoint weights loaded. Train/eval mode
        and ``requires_grad`` flags are left untouched.
    """
    model = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="teacher",
        raw_obs_dim=raw_obs_dim,
        mu_dim=mu_dim,
        latent_dim=latent_dim,
    )
    ckpt = torch.load(path, map_location=device, weights_only=True)
    # skrl saves under "policy" key with state_dict; a bare state_dict is
    # accepted as well.
    state_dict = ckpt["policy"] if "policy" in ckpt else ckpt
    model.load_state_dict(state_dict)
    return model


def _build_deploy_model(
    teacher: SharedMLP,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    history_length: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Create a deploy-mode SharedMLP and copy backbone + heads from teacher.

    Only ``net``, ``mean_layer``, ``value_layer`` and ``log_std_parameter``
    are copied. The adaptation module keeps its fresh initialisation and is
    what Phase 2 trains.
    """
    model = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="deploy",
        raw_obs_dim=raw_obs_dim,
        history_length=history_length,
        latent_dim=latent_dim,
    )
    # Copy backbone, policy head, value head from teacher.
    model.net.load_state_dict(teacher.net.state_dict())
    model.mean_layer.load_state_dict(teacher.mean_layer.state_dict())
    model.value_layer.load_state_dict(teacher.value_layer.state_dict())
    model.log_std_parameter.data.copy_(teacher.log_std_parameter.data)
    return model


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
    parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
    parser.add_argument("--env", default="rotary_cartpole")
    parser.add_argument("--robot-path", default="assets/rotary_cartpole")
    parser.add_argument("--num-envs", type=int, default=64)
    parser.add_argument("--iterations", type=int, default=2000)
    parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--latent-dim", type=int, default=8)
    parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
    parser.add_argument("--history-length", type=int, default=10)
    parser.add_argument("--output", default="checkpoints/adaptation.pt")
    parser.add_argument("--device", default="cpu")
    return parser.parse_args()


def _predict_latent(
    deploy_model: SharedMLP,
    aug_obs: torch.Tensor,
    raw_obs_dim: int,
    history_length: int,
    step_dim: int,
) -> torch.Tensor:
    """Run the adaptation module on the history part of a deploy observation.

    ``aug_obs`` is laid out as ``[raw_obs, history_flat]``. The flat history
    is reshaped to ``(batch, history_length, step_dim)`` before encoding.
    Gradients flow unless the caller disables them.
    """
    hist_flat = aug_obs[:, raw_obs_dim:]
    history = hist_flat.reshape(-1, history_length, step_dim)
    return deploy_model.adaptation_module(history)


def _collect_latent_pairs(
    runner: MuJoCoRunner,
    deploy_model: SharedMLP,
    teacher: SharedMLP,
    num_steps: int,
    raw_obs_dim: int,
    history_length: int,
    step_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Roll out the deploy policy once and gather (ẑ, z) training pairs.

    The runner is reset first, so every batch starts from fresh episodes.
    Actions and teacher targets z are computed without gradients. The
    adaptation-module predictions ẑ keep their autograd graph, so the caller
    can backpropagate a loss through them.

    Returns:
        ``(z_pred, z_target)``, each the concatenation over all steps that
        supplied privileged observations.

    Raises:
        RuntimeError: if the runner never reported ``privileged_obs`` in its
            step info, since there would be nothing to supervise against.
    """
    obs, _ = runner.reset()
    z_preds: list[torch.Tensor] = []
    z_targets: list[torch.Tensor] = []

    for _ in range(num_steps):
        with torch.no_grad():
            # obs is already augmented by the runner ([raw_obs, history_flat]).
            actions = deploy_model.act({"states": obs}, role="policy")[0]
        obs, _, _, _, info = runner.step(actions)

        mu = info.get("privileged_obs")
        if mu is None:
            continue
        with torch.no_grad():
            z_targets.append(teacher.env_encoder(mu))
        # Predict from the observation returned by the same step() call as μ.
        # Both lists are appended together so the pairs can never drift out of
        # alignment.
        # NOTE(review): if the runner auto-resets finished envs inside step(),
        # using the pre-step observation here would pair an old episode's
        # history with the new episode's μ.
        z_preds.append(
            _predict_latent(deploy_model, obs, raw_obs_dim, history_length, step_dim)
        )

    if not z_targets:
        raise RuntimeError(
            "Runner did not provide 'privileged_obs' in step info; cannot "
            "supervise the adaptation module (is domain randomization / "
            "rma_mode configured correctly?)"
        )
    return torch.cat(z_preds, dim=0), torch.cat(z_targets, dim=0)


def main() -> None:
    """Train the RMA adaptation module against a frozen Phase 1 teacher.

    Each iteration collects one rollout with the deploy policy, regresses
    ẑ = φ(history) onto the teacher latent z = env_encoder(μ) with MSE, and
    updates only the adaptation module. The full deploy-model state dict is
    saved to ``--output`` at the end.
    """
    args = _parse_args()
    device = torch.device(args.device)
    hidden_sizes = tuple(args.hidden_sizes)

    # ── Build env + runner (deploy mode with history + DR) ───────
    env_cfg = OmegaConf.create({"env": {
        "robot_path": args.robot_path,
    }})
    env = build_env(args.env, env_cfg)
    runner_cfg = MuJoCoRunnerConfig(
        num_envs=args.num_envs,
        device=args.device,
        history_length=args.history_length,
        rma_mode="deploy",
        domain_rand={
            "qpos_noise_std": 0.01,
            "qvel_noise_std": 0.5,
            "action_delay_steps": [0, 2],
            "friction_scale": [0.6, 1.6],
            "damping_scale": [0.6, 1.6],
            "torque_scale": [0.85, 1.15],
        },
    )
    runner = MuJoCoRunner(env=env, config=runner_cfg)

    # Close the runner even when loading or training fails.
    try:
        raw_obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        mu_dim = runner.privileged_dim
        log.info(
            "adaptation_setup",
            raw_obs_dim=raw_obs_dim,
            act_dim=act_dim,
            mu_dim=mu_dim,
            latent_dim=args.latent_dim,
            history_length=args.history_length,
        )

        # ── Load teacher & build deploy model ────────────────────
        # Teacher obs space: [raw_obs, μ]
        teacher_obs_space = spaces.Box(
            low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
        )
        teacher = _load_teacher_checkpoint(
            path=args.checkpoint,
            obs_space=teacher_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            mu_dim=mu_dim,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)

        # Deploy obs space: [raw_obs, history_flat]; each history step holds
        # one observation and one action.
        step_dim = raw_obs_dim + act_dim
        deploy_obs_space = spaces.Box(
            low=-torch.inf,
            high=torch.inf,
            shape=(raw_obs_dim + args.history_length * step_dim,),
        )
        deploy_model = _build_deploy_model(
            teacher=teacher,
            obs_space=deploy_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            history_length=args.history_length,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )

        # Freeze everything except the adaptation module.
        for name, param in deploy_model.named_parameters():
            if "adaptation_module" not in name:
                param.requires_grad_(False)
        optimizer = torch.optim.Adam(
            deploy_model.adaptation_module.parameters(), lr=args.lr,
        )

        # ── Training loop ────────────────────────────────────────
        log.info("starting_adaptation_training", iterations=args.iterations)
        for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
            # A single rollout per iteration: predictions are collected with
            # gradients enabled, so no second "grad" rollout is needed.
            z_pred, z_target = _collect_latent_pairs(
                runner=runner,
                deploy_model=deploy_model,
                teacher=teacher,
                num_steps=args.rollout_steps,
                raw_obs_dim=raw_obs_dim,
                history_length=args.history_length,
                step_dim=step_dim,
            )
            loss = torch.nn.functional.mse_loss(z_pred, z_target.detach())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iteration % 50 == 0:
                log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")

        # ── Save adaptation weights ──────────────────────────────
        out_path = pathlib.Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Save the full deploy model state dict.
        torch.save(deploy_model.state_dict(), out_path)
        log.info("adaptation_saved", path=str(out_path))
    finally:
        runner.close()


if __name__ == "__main__":
    main()