Close the sim2real gap for the Furuta pendulum (swings up but can't balance on hardware). Root causes were (a) no domain randomization, so the policy overfit one deterministic sim instance, and (b) reward design flaws that produced degenerate policies. Domain randomization (runner-level, backend-agnostic): - BaseRunner: domain_rand config; per-env action-delay buffer (latency), Gaussian qpos/qvel sensor noise, per-env dynamics-scale sampling (friction/damping/torque), resampled per episode. Sensor noise per step. - privileged_obs/privileged_dim expose normalized DR factors (mu) for RMA. - step() now uses clean state for reward/termination, noisy state for the observation the policy sees. - MuJoCoRunner: applies per-env friction/damping/torque scales. - robot.py: compute_motor_force gains friction/damping scale args. - Configs: DR blocks for mujoco (full) and mjx (delay+noise); clean defaults for mujoco_single/serial; noise/delay anchored to recordings. Reward fixes (rotary_cartpole): - Shift upright reward to [0,1] (was [-1,1]) + alive_bonus, so surviving always beats ending early (kills the "suicide into the limit" policy). - Add balance_bonus * upright * stillness so reward requires upright AND near-zero pendulum velocity (kills the "spin in full loops" policy). Deploy: - eval.py load_policy reconstructs the history/adaptation encoder (auto-detects its dim from the checkpoint) so DR+embedding policies load. Fixes: - MuJoCoRunner._sim_reset referenced self._env (typo) -> self.env, which was breaking every rotary-cartpole reset. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
297 lines
10 KiB
Python
297 lines
10 KiB
Python
"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.
|
|
|
|
Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
|
|
and trains a HistoryEncoder (adaptation module) to predict the teacher's
|
|
latent z from observation-action history using supervised MSE.
|
|
|
|
Usage:
|
|
python scripts/train_adaptation.py \
|
|
--checkpoint runs/<run>/checkpoints/agent_XXXXX.pt \
|
|
--env rotary_cartpole \
|
|
--robot-path assets/rotary_cartpole \
|
|
--num-envs 64 \
|
|
--iterations 2000 \
|
|
--lr 3e-4
|
|
"""
|
|
|
|
import argparse
|
|
import pathlib
|
|
import sys
|
|
|
|
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
|
if _PROJECT_ROOT not in sys.path:
|
|
sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
import structlog
|
|
import torch
|
|
import tqdm
|
|
from gymnasium import spaces
|
|
from omegaconf import OmegaConf
|
|
|
|
from src.core.registry import build_env
|
|
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
|
|
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
|
|
|
log = structlog.get_logger()
|
|
|
|
|
|
def _load_teacher_checkpoint(
    path: str,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    mu_dim: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Build a teacher-mode SharedMLP and restore its weights from ``path``.

    The checkpoint may either wrap the state dict under a ``"policy"`` entry
    (the layout skrl writes) or be the bare state dict itself; both layouts
    are accepted.
    """
    teacher_kwargs = {
        "observation_space": obs_space,
        "action_space": act_space,
        "device": device,
        "hidden_sizes": hidden_sizes,
        "rma_mode": "teacher",
        "raw_obs_dim": raw_obs_dim,
        "mu_dim": mu_dim,
        "latent_dim": latent_dim,
    }
    teacher = SharedMLP(**teacher_kwargs)

    checkpoint = torch.load(path, map_location=device, weights_only=True)
    # Unwrap the skrl-style {"policy": state_dict} container when present.
    state_dict = checkpoint["policy"] if "policy" in checkpoint else checkpoint
    teacher.load_state_dict(state_dict)
    return teacher
|
|
|
|
|
|
def _build_deploy_model(
    teacher: SharedMLP,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    history_length: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Instantiate a deploy-mode SharedMLP seeded with the teacher's weights.

    The backbone (``net``), policy head (``mean_layer``), value head
    (``value_layer``) and ``log_std_parameter`` are copied from ``teacher``.
    Nothing else is copied, so any remaining parameters of the new model keep
    their fresh initialisation.
    """
    deploy = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="deploy",
        raw_obs_dim=raw_obs_dim,
        history_length=history_length,
        latent_dim=latent_dim,
    )

    # Sub-modules shared between teacher and deploy variants.
    for shared_name in ("net", "mean_layer", "value_layer"):
        source_module = getattr(teacher, shared_name)
        target_module = getattr(deploy, shared_name)
        target_module.load_state_dict(source_module.state_dict())

    # log_std is a bare parameter rather than a module, so copy it in place.
    deploy.log_std_parameter.data.copy_(teacher.log_std_parameter.data)

    return deploy
|
|
|
|
|
|
def _collect_latent_pairs(
    runner: "MuJoCoRunner",
    teacher: "SharedMLP",
    deploy_model: "SharedMLP",
    obs: torch.Tensor,
    rollout_steps: int,
    raw_obs_dim: int,
    history_length: int,
    step_dim: int,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Roll the deploy policy forward and gather (ẑ, z) supervision pairs.

    For each step the action is chosen by ``deploy_model`` without gradients,
    the runner is stepped, and -- when the step's ``info`` carries
    ``"privileged_obs"`` -- the teacher's latent (target, no gradients) and the
    adaptation module's latent (prediction, WITH gradients) are recorded.
    Prediction and target are always appended together so the two lists stay
    aligned even if some steps carry no privileged observation.

    Returns:
        ``(z_preds, z_targets)``: equally long lists of per-step tensors.
    """
    z_preds: list[torch.Tensor] = []
    z_targets: list[torch.Tensor] = []

    for _ in range(rollout_steps):
        # Observation the policy acted on; already augmented by the runner.
        policy_obs = obs
        with torch.no_grad():
            actions = deploy_model.act({"states": policy_obs}, role="policy")[0]

        obs, _, _, _, info = runner.step(actions)

        mu = info.get("privileged_obs")
        if mu is None:
            continue

        # Target: teacher's z from privileged μ (teacher is frozen).
        with torch.no_grad():
            z_targets.append(teacher.env_encoder(mu))

        # Prediction: everything after the raw observation is the flattened
        # history, reshaped to (batch, history_length, step_dim).
        # NOTE(review): μ comes from the post-step info while the history is
        # the pre-step observation; if the runner auto-resets an env inside
        # step(), that pair may straddle two episodes -- confirm.
        history = policy_obs[:, raw_obs_dim:].reshape(-1, history_length, step_dim)
        z_preds.append(deploy_model.adaptation_module(history))

    return z_preds, z_targets


def _train_adaptation(
    args: argparse.Namespace,
    env,
    runner: "MuJoCoRunner",
    device: torch.device,
) -> None:
    """Load the teacher, train the adaptation module, and save the result.

    Args:
        args: Parsed command-line arguments from ``main``.
        env: Environment returned by ``build_env``; supplies the observation
            and action spaces.
        runner: Runner that provides augmented observations and
            ``privileged_obs`` in the step info.
        device: Torch device for both models.
    """
    hidden_sizes = tuple(args.hidden_sizes)

    raw_obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    mu_dim = runner.privileged_dim

    log.info(
        "adaptation_setup",
        raw_obs_dim=raw_obs_dim,
        act_dim=act_dim,
        mu_dim=mu_dim,
        latent_dim=args.latent_dim,
        history_length=args.history_length,
    )

    # ── Load teacher & build deploy model ────────────────────────
    # Teacher obs space: [raw_obs, μ]
    teacher_obs_space = spaces.Box(
        low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
    )
    teacher = _load_teacher_checkpoint(
        path=args.checkpoint,
        obs_space=teacher_obs_space,
        act_space=env.action_space,
        device=device,
        raw_obs_dim=raw_obs_dim,
        mu_dim=mu_dim,
        hidden_sizes=hidden_sizes,
        latent_dim=args.latent_dim,
    )
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)

    # Deploy obs space: [raw_obs, history_flat]
    step_dim = raw_obs_dim + act_dim
    deploy_obs_space = spaces.Box(
        low=-torch.inf, high=torch.inf,
        shape=(raw_obs_dim + args.history_length * step_dim,),
    )
    deploy_model = _build_deploy_model(
        teacher=teacher,
        obs_space=deploy_obs_space,
        act_space=env.action_space,
        device=device,
        raw_obs_dim=raw_obs_dim,
        history_length=args.history_length,
        hidden_sizes=hidden_sizes,
        latent_dim=args.latent_dim,
    )

    # Freeze everything except the adaptation module.
    for name, param in deploy_model.named_parameters():
        if "adaptation_module" not in name:
            param.requires_grad_(False)

    optimizer = torch.optim.Adam(
        deploy_model.adaptation_module.parameters(), lr=args.lr,
    )

    # ── Training loop ────────────────────────────────────────────
    log.info("starting_adaptation_training", iterations=args.iterations)
    warned_no_privileged = False

    for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
        # Every update is trained on a rollout that starts from a fresh reset,
        # matching the data the previous implementation actually trained on
        # (its extra no-grad rollout before the reset was discarded).
        obs, _ = runner.reset()
        z_preds, z_targets = _collect_latent_pairs(
            runner=runner,
            teacher=teacher,
            deploy_model=deploy_model,
            obs=obs,
            rollout_steps=args.rollout_steps,
            raw_obs_dim=raw_obs_dim,
            history_length=args.history_length,
            step_dim=step_dim,
        )

        if not z_targets:
            # Without privileged observations there is nothing to regress
            # onto; say so once instead of silently skipping every update.
            if not warned_no_privileged:
                log.warning("adaptation_no_privileged_obs", iteration=iteration)
                warned_no_privileged = True
            continue

        # Supervised update on adaptation module. Targets were produced under
        # no_grad, so gradients only flow through the predictions.
        z_pred_all = torch.cat(z_preds, dim=0)
        z_target_all = torch.cat(z_targets, dim=0)
        loss = torch.nn.functional.mse_loss(z_pred_all, z_target_all)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % 50 == 0:
            log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")

    # ── Save adaptation weights ──────────────────────────────────
    out_path = pathlib.Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Save the full deploy model state dict.
    torch.save(deploy_model.state_dict(), out_path)
    log.info("adaptation_saved", path=str(out_path))


def main() -> None:
    """Command-line entry point for RMA Phase 2 (adaptation-module training).

    Parses arguments, builds the environment and a deploy-mode MuJoCo runner
    with domain randomization, then trains the adaptation module to regress
    the frozen teacher's latent and writes the deploy model's state dict to
    ``--output``. The runner is closed even if training raises.
    """
    parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
    parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
    parser.add_argument("--env", default="rotary_cartpole")
    parser.add_argument("--robot-path", default="assets/rotary_cartpole")
    parser.add_argument("--num-envs", type=int, default=64)
    parser.add_argument("--iterations", type=int, default=2000)
    parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--latent-dim", type=int, default=8)
    parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
    parser.add_argument("--history-length", type=int, default=10)
    parser.add_argument("--output", default="checkpoints/adaptation.pt")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    device = torch.device(args.device)

    # ── Build env + runner (deploy mode with history + DR) ───────
    env_cfg = OmegaConf.create({"env": {
        "robot_path": args.robot_path,
    }})
    env = build_env(args.env, env_cfg)

    runner_cfg = MuJoCoRunnerConfig(
        num_envs=args.num_envs,
        device=args.device,
        history_length=args.history_length,
        rma_mode="deploy",
        # Domain-randomization settings used while collecting adaptation data.
        domain_rand={
            "qpos_noise_std": 0.01,
            "qvel_noise_std": 0.5,
            "action_delay_steps": [0, 2],
            "friction_scale": [0.6, 1.6],
            "damping_scale": [0.6, 1.6],
            "torque_scale": [0.85, 1.15],
        },
    )
    runner = MuJoCoRunner(env=env, config=runner_cfg)

    try:
        _train_adaptation(args, env, runner, device)
    finally:
        # Release the simulator even when loading/training fails part-way.
        runner.close()
|
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|