Files
RL-Sim-Framework/scripts/train_adaptation.py
Victor Mylle b37cd26690 feat: sim2real domain randomization + reward fixes for rotary cartpole
Close the sim2real gap for the Furuta pendulum (swings up but can't
balance on hardware). Root causes were (a) no domain randomization, so
the policy overfit one deterministic sim instance, and (b) reward design
flaws that produced degenerate policies.

Domain randomization (runner-level, backend-agnostic):
- BaseRunner: domain_rand config; per-env action-delay buffer (latency),
  Gaussian qpos/qvel sensor noise, per-env dynamics-scale sampling
  (friction/damping/torque), resampled per episode. Sensor noise per step.
- privileged_obs/privileged_dim expose normalized DR factors (mu) for RMA.
- step() now uses clean state for reward/termination, noisy state for the
  observation the policy sees.
- MuJoCoRunner: applies per-env friction/damping/torque scales.
- robot.py: compute_motor_force gains friction/damping scale args.
- Configs: DR blocks for mujoco (full) and mjx (delay+noise); clean
  defaults for mujoco_single/serial; noise/delay anchored to recordings.

Reward fixes (rotary_cartpole):
- Shift upright reward to [0,1] (was [-1,1]) + alive_bonus, so surviving
  always beats ending early (kills the "suicide into the limit" policy).
- Add balance_bonus * upright * stillness so reward requires upright AND
  near-zero pendulum velocity (kills the "spin in full loops" policy).

Deploy:
- eval.py load_policy reconstructs the history/adaptation encoder
  (auto-detects its dim from the checkpoint) so DR+embedding policies load.

Fixes:
- MuJoCoRunner._sim_reset referenced self._env (typo) -> self.env, which
  was breaking every rotary-cartpole reset.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-09 20:48:25 +02:00

297 lines
10 KiB
Python

"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.
Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
and trains a HistoryEncoder (adaptation module) to predict the teacher's
latent z from observation-action history using supervised MSE.
Usage:
python scripts/train_adaptation.py \
--checkpoint runs/<run>/checkpoints/agent_XXXXX.pt \
--env rotary_cartpole \
--robot-path assets/rotary_cartpole \
--num-envs 64 \
--iterations 2000 \
--lr 3e-4
"""
import argparse
import pathlib
import sys

# Put the repository root (the parent of this script's directory) on
# sys.path so the `src.*` imports below resolve when the script is run
# directly, e.g. `python scripts/train_adaptation.py ...` as in the module
# docstring. This must execute before the `src` imports.
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import structlog
import torch
import tqdm
from gymnasium import spaces
from omegaconf import OmegaConf

from src.core.registry import build_env
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig

# Module-wide structured logger (structlog); events are logged as
# `log.info("event_name", key=value, ...)`.
log = structlog.get_logger()
def _load_teacher_checkpoint(
    path: str, obs_space: spaces.Space, act_space: spaces.Space,
    device: torch.device, raw_obs_dim: int, mu_dim: int,
    hidden_sizes: tuple[int, ...], latent_dim: int,
) -> SharedMLP:
    """Rebuild the Phase 1 teacher network and restore its saved weights.

    Args:
        path: Checkpoint file produced by the Phase 1 (teacher) run.
        obs_space: Teacher observation space (raw observation followed by
            the privileged vector mu).
        act_space: Action space of the environment.
        device: Device the model is built on; checkpoint tensors are mapped
            onto it.
        raw_obs_dim: Width of the raw observation part of `obs_space`.
        mu_dim: Width of the privileged part of `obs_space`.
        hidden_sizes: Backbone layer widths; must match the teacher run.
        latent_dim: Size of the environment latent z; must match the
            teacher run.

    Returns:
        A `SharedMLP` in ``rma_mode="teacher"`` with the checkpoint loaded.
    """
    teacher = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="teacher",
        raw_obs_dim=raw_obs_dim,
        mu_dim=mu_dim,
        latent_dim=latent_dim,
    )
    state = torch.load(path, map_location=device, weights_only=True)
    # skrl agent checkpoints nest the policy weights under a "policy" key;
    # anything else is treated as a bare state dict.
    weights = state["policy"] if "policy" in state else state
    teacher.load_state_dict(weights)
    return teacher
def _build_deploy_model(
    teacher: SharedMLP,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    history_length: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Create a deploy-mode SharedMLP that shares the teacher's policy weights.

    The backbone (`net`), the policy mean head, the value head and the
    log-std parameter are copied from `teacher`. Any parameter not copied
    here (e.g. the adaptation module) keeps its fresh initialization.

    Args:
        teacher: Trained Phase 1 model to copy weights from.
        obs_space: Deploy observation space (raw observation followed by the
            flattened history).
        act_space: Action space of the environment.
        device: Device the new model is built on.
        raw_obs_dim: Width of the raw observation part of `obs_space`.
        history_length: Number of history steps the model consumes.
        hidden_sizes: Backbone layer widths; must match the teacher.
        latent_dim: Size of the latent z; must match the teacher.

    Returns:
        The new deploy-mode `SharedMLP`.
    """
    student = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="deploy",
        raw_obs_dim=raw_obs_dim,
        history_length=history_length,
        latent_dim=latent_dim,
    )
    for submodule_name in ("net", "mean_layer", "value_layer"):
        source = getattr(teacher, submodule_name)
        getattr(student, submodule_name).load_state_dict(source.state_dict())
    student.log_std_parameter.data.copy_(teacher.log_std_parameter.data)
    return student
def main() -> None:
parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
parser.add_argument("--env", default="rotary_cartpole")
parser.add_argument("--robot-path", default="assets/rotary_cartpole")
parser.add_argument("--num-envs", type=int, default=64)
parser.add_argument("--iterations", type=int, default=2000)
parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--latent-dim", type=int, default=8)
parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
parser.add_argument("--history-length", type=int, default=10)
parser.add_argument("--output", default="checkpoints/adaptation.pt")
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
device = torch.device(args.device)
hidden_sizes = tuple(args.hidden_sizes)
# ── Build env + runner (deploy mode with history + DR) ───────
env_cfg = OmegaConf.create({"env": {
"robot_path": args.robot_path,
}})
env = build_env(args.env, env_cfg)
runner_cfg = MuJoCoRunnerConfig(
num_envs=args.num_envs,
device=args.device,
history_length=args.history_length,
rma_mode="deploy",
domain_rand={
"qpos_noise_std": 0.01,
"qvel_noise_std": 0.5,
"action_delay_steps": [0, 2],
"friction_scale": [0.6, 1.6],
"damping_scale": [0.6, 1.6],
"torque_scale": [0.85, 1.15],
},
)
runner = MuJoCoRunner(env=env, config=runner_cfg)
raw_obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
mu_dim = runner.privileged_dim
log.info(
"adaptation_setup",
raw_obs_dim=raw_obs_dim,
act_dim=act_dim,
mu_dim=mu_dim,
latent_dim=args.latent_dim,
history_length=args.history_length,
)
# ── Load teacher & build deploy model ────────────────────────
# Teacher obs space: [raw_obs, μ]
teacher_obs_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
)
teacher = _load_teacher_checkpoint(
path=args.checkpoint,
obs_space=teacher_obs_space,
act_space=env.action_space,
device=device,
raw_obs_dim=raw_obs_dim,
mu_dim=mu_dim,
hidden_sizes=hidden_sizes,
latent_dim=args.latent_dim,
)
teacher.eval()
for p in teacher.parameters():
p.requires_grad_(False)
# Deploy obs space: [raw_obs, history_flat]
step_dim = raw_obs_dim + act_dim
deploy_obs_space = spaces.Box(
low=-torch.inf, high=torch.inf,
shape=(raw_obs_dim + args.history_length * step_dim,),
)
deploy_model = _build_deploy_model(
teacher=teacher,
obs_space=deploy_obs_space,
act_space=env.action_space,
device=device,
raw_obs_dim=raw_obs_dim,
history_length=args.history_length,
hidden_sizes=hidden_sizes,
latent_dim=args.latent_dim,
)
# Freeze everything except the adaptation module.
for name, param in deploy_model.named_parameters():
if "adaptation_module" not in name:
param.requires_grad_(False)
optimizer = torch.optim.Adam(
deploy_model.adaptation_module.parameters(), lr=args.lr,
)
# ── Training loop ────────────────────────────────────────────
log.info("starting_adaptation_training", iterations=args.iterations)
obs, _ = runner.reset()
for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
# Collect a rollout using the deploy model.
z_targets: list[torch.Tensor] = []
z_preds: list[torch.Tensor] = []
for _step in range(args.rollout_steps):
with torch.no_grad():
# Get action from deploy model (uses adaptation module).
aug_obs = obs # already augmented by runner
actions = deploy_model.act(
{"states": aug_obs}, role="policy",
)[0]
obs, _, _, _, info = runner.step(actions)
# Compute teacher's z from privileged μ.
mu = info.get("privileged_obs")
if mu is not None:
z_target = teacher.env_encoder(mu)
z_targets.append(z_target)
# Compute adaptation module's ẑ from history.
raw = aug_obs[:, :raw_obs_dim]
hist_flat = aug_obs[:, raw_obs_dim:]
history = hist_flat.reshape(
-1, args.history_length, step_dim,
)
z_pred = deploy_model.adaptation_module(history)
z_preds.append(z_pred)
if not z_targets:
continue
# Supervised update on adaptation module.
z_target_batch = torch.cat(z_targets, dim=0).detach()
z_pred_batch = torch.cat(z_preds, dim=0)
# Re-compute z_pred with gradients (the ones above were no_grad).
# We need to re-encode from stored data; instead, collect with grad:
# Actually, z_preds were computed in no_grad. Let me re-collect
# a fresh batch with gradients.
obs_reset, _ = runner.reset()
obs = obs_reset
z_targets_grad: list[torch.Tensor] = []
z_preds_grad: list[torch.Tensor] = []
for _step in range(args.rollout_steps):
with torch.no_grad():
aug_obs = obs
actions = deploy_model.act(
{"states": aug_obs}, role="policy",
)[0]
obs, _, _, _, info = runner.step(actions)
mu = info.get("privileged_obs")
if mu is not None:
with torch.no_grad():
z_target = teacher.env_encoder(mu)
z_targets_grad.append(z_target)
# This time, compute z_pred WITH gradients.
raw = aug_obs[:, :raw_obs_dim]
hist_flat = aug_obs[:, raw_obs_dim:]
history = hist_flat.reshape(
-1, args.history_length, step_dim,
)
z_pred = deploy_model.adaptation_module(history)
z_preds_grad.append(z_pred)
if not z_targets_grad:
continue
z_target_all = torch.cat(z_targets_grad, dim=0).detach()
z_pred_all = torch.cat(z_preds_grad, dim=0)
loss = torch.nn.functional.mse_loss(z_pred_all, z_target_all)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if iteration % 50 == 0:
log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")
# ── Save adaptation weights ──────────────────────────────────
out_path = pathlib.Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
# Save the full deploy model state dict.
torch.save(deploy_model.state_dict(), out_path)
log.info("adaptation_saved", path=str(out_path))
runner.close()
if __name__ == "__main__":
main()