♻️ full agent refactor

This commit is contained in:
2026-06-10 21:15:34 +02:00
parent a98e86ef66
commit 1e0836e1bc
49 changed files with 1309 additions and 829 deletions

View File

@@ -75,14 +75,13 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
"""Return the history/adaptation encoder output dim, if present.
"""Return the history encoder output dim, if present.
Lets eval reconstruct an embedding policy without knowing the training
embedding_dim/latent_dim — read it straight from the saved weights.
embedding_dim — read it straight from the saved weights.
"""
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
if key in state_dict:
return state_dict[key].shape[0]
if "history_encoder.fc.weight" in state_dict:
return state_dict["history_encoder.fc.weight"].shape[0]
return None
@@ -92,14 +91,13 @@ def load_policy(
action_space: spaces.Space,
device: torch.device = torch.device("cpu"),
history_length: int = 0,
rma_mode: str = "none",
raw_obs_dim: int = 0,
) -> tuple[SharedMLP, RunningStandardScaler]:
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
For DR + history-embedding policies (history_length > 0) or RMA deploy
policies (rma_mode="deploy"), the history/adaptation encoder must be
reconstructed too — its output dim is read back from the saved weights.
For DR + history-embedding policies (history_length > 0), the history
encoder is reconstructed too — its output dim is read back from the
saved weights.
Returns:
(model, state_preprocessor) ready for inference.
@@ -117,11 +115,9 @@ def load_policy(
action_space=action_space,
device=device,
hidden_sizes=hidden_sizes,
history_length=history_length,
rma_mode=rma_mode,
history_length=history_length if enc_out else 0,
raw_obs_dim=raw_obs_dim,
embedding_dim=enc_out or 32, # legacy "none" + history
latent_dim=enc_out or 8, # RMA deploy adaptation module
embedding_dim=enc_out or 32,
)
model.load_state_dict(ckpt["policy"])
model.eval()
@@ -189,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole")
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco_single")
checkpoint_path = cfg.get("checkpoint", None)
@@ -222,7 +218,6 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy(
checkpoint_path, runner.observation_space, runner.action_space, device,
history_length=runner.config.history_length,
rma_mode=runner.config.rma_mode,
raw_obs_dim=runner.env.observation_space.shape[0],
)
@@ -311,7 +306,6 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
model, preprocessor = load_policy(
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
history_length=serial_runner.config.history_length,
rma_mode=serial_runner.config.rma_mode,
raw_obs_dim=serial_runner.env.observation_space.shape[0],
)
@@ -339,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
if _reset_flag[0]:
_reset_flag[0] = False
serial_runner._send("M0")
serial_runner._drive_to_center()
serial_runner._wait_for_pendulum_still()
obs, _ = serial_runner.reset()
obs, _ = serial_runner.reset() # drives to center + settles
step = 0
episode += 1
episode_reward = 0.0
@@ -376,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
"step", n=step, reward=round(reward.item(), 3),
action=round(action[0, 0].item(), 2),
ep_reward=round(episode_reward, 1),
motor_enc=state["encoder_count"],
pend_deg=round(state["pendulum_angle"], 1),
motor_deg=round(math.degrees(state["motor_rad"]), 1),
pend_deg=round(math.degrees(state["pend_rad"]), 1),
)
# Check for safety / disconnection.

View File

@@ -352,7 +352,7 @@ def main() -> None:
reuse_last_task_id=False,
)
task.set_base_docker(
docker_image="registry.kube.optimize/worker-image:latest",
docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
docker_arguments=[
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",

View File

@@ -63,7 +63,7 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"""Initialize ClearML task with project structure and tags."""
Task.ignore_requirements("torch")
env_name = choices.get("env", "cartpole")
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco")
training_name = choices.get("training", "ppo")
@@ -113,7 +113,7 @@ def main(cfg: DictConfig) -> None:
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
env_name = choices.get("env", "cartpole")
env_name = choices.get("env", "rotary_cartpole")
env = build_env(env_name, cfg)
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
trainer_config = TrainerConfig(**training_dict)

View File

@@ -1,296 +0,0 @@
"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.
Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
and trains a HistoryEncoder (adaptation module) to predict the teacher's
latent z from observation-action history using supervised MSE.
Usage:
python scripts/train_adaptation.py \
--checkpoint runs/<run>/checkpoints/agent_XXXXX.pt \
--env rotary_cartpole \
--robot-path assets/rotary_cartpole \
--num-envs 64 \
--iterations 2000 \
--lr 3e-4
"""
import argparse
import pathlib
import sys
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
import structlog
import torch
import tqdm
from gymnasium import spaces
from omegaconf import OmegaConf
from src.core.registry import build_env
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
log = structlog.get_logger()
def _load_teacher_checkpoint(
    path: str, obs_space: spaces.Space, act_space: spaces.Space,
    device: torch.device, raw_obs_dim: int, mu_dim: int,
    hidden_sizes: tuple[int, ...], latent_dim: int,
) -> SharedMLP:
    """Build a teacher-mode SharedMLP and restore its weights from ``path``.

    The model is constructed with ``rma_mode="teacher"`` and the given
    dimensions, then the checkpoint is read onto ``device`` with
    ``weights_only=True``.  Checkpoints written by skrl keep the model
    weights under a ``"policy"`` key; anything else is treated as a bare
    state dict.

    Returns:
        The SharedMLP with the checkpoint weights loaded.
    """
    teacher_kwargs = dict(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="teacher",
        raw_obs_dim=raw_obs_dim,
        mu_dim=mu_dim,
        latent_dim=latent_dim,
    )
    teacher = SharedMLP(**teacher_kwargs)

    checkpoint = torch.load(path, map_location=device, weights_only=True)
    # Unwrap the skrl container when present, otherwise use the dict as-is.
    weights = checkpoint["policy"] if "policy" in checkpoint else checkpoint
    teacher.load_state_dict(weights)
    return teacher
def _build_deploy_model(
    teacher: SharedMLP,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    history_length: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Instantiate a deploy-mode SharedMLP that reuses the teacher's weights.

    A fresh model is created with ``rma_mode="deploy"``; its backbone
    (``net``), policy head (``mean_layer``), value head (``value_layer``)
    and ``log_std_parameter`` are then overwritten with the teacher's
    values.  Nothing else on the new model is touched here.

    Returns:
        The deploy-mode SharedMLP.
    """
    deploy_kwargs = dict(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="deploy",
        raw_obs_dim=raw_obs_dim,
        history_length=history_length,
        latent_dim=latent_dim,
    )
    student = SharedMLP(**deploy_kwargs)

    # Transfer the shared sub-modules one by one, in the same order as before.
    for submodule in ("net", "mean_layer", "value_layer"):
        source_weights = getattr(teacher, submodule).state_dict()
        getattr(student, submodule).load_state_dict(source_weights)

    # log_std is a bare parameter rather than a module: copy it in place.
    student.log_std_parameter.data.copy_(teacher.log_std_parameter.data)
    return student
def _collect_adaptation_batch(
    runner: "MuJoCoRunner",
    deploy_model: "SharedMLP",
    teacher: "SharedMLP",
    rollout_steps: int,
    history_length: int,
    raw_obs_dim: int,
    step_dim: int,
) -> "tuple[torch.Tensor, torch.Tensor] | None":
    """Roll out the deploy policy once and gather (ẑ, z) training pairs.

    The runner is reset, then stepped ``rollout_steps`` times with actions
    from ``deploy_model``.  For every step whose ``info`` carries a
    ``"privileged_obs"`` entry, the teacher's ``env_encoder`` output is
    recorded as the target (no gradients) and the adaptation module's
    output on the pre-step history is recorded as the prediction (with
    gradients, so the caller can back-propagate through it).

    Returns:
        ``(z_pred, z_target)`` concatenated over all recorded steps, or
        ``None`` when no step provided privileged observations.
    """
    obs, _ = runner.reset()
    z_preds: list[torch.Tensor] = []
    z_targets: list[torch.Tensor] = []

    for _ in range(rollout_steps):
        aug_obs = obs  # already augmented by the runner: [raw_obs, history_flat]
        with torch.no_grad():
            # Acting and stepping never need gradients.
            actions = deploy_model.act({"states": aug_obs}, role="policy")[0]
            obs, _, _, _, info = runner.step(actions)

        mu = info.get("privileged_obs")
        if mu is None:
            continue

        with torch.no_grad():
            z_targets.append(teacher.env_encoder(mu))

        # Everything after the raw observation is the flattened history.
        history = aug_obs[:, raw_obs_dim:].reshape(-1, history_length, step_dim)
        # Deliberately outside no_grad: this is the tensor the loss trains.
        z_preds.append(deploy_model.adaptation_module(history))

    if not z_targets:
        return None
    return torch.cat(z_preds, dim=0), torch.cat(z_targets, dim=0).detach()


def main() -> None:
    """RMA Phase 2 entry point: fit the adaptation module to the teacher's z.

    Parses CLI arguments, builds the env and a deploy-mode MuJoCo runner
    with domain randomisation, loads the frozen teacher, builds a deploy
    model whose only trainable part is ``adaptation_module``, and trains it
    with an MSE loss against the teacher's ``env_encoder`` output.  The
    full deploy-model state dict is written to ``--output``.

    Each iteration performs exactly one rollout (started from a fresh
    reset) that is used for the gradient step.  The runner is always
    closed, even if training raises.
    """
    parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
    parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
    parser.add_argument("--env", default="rotary_cartpole")
    parser.add_argument("--robot-path", default="assets/rotary_cartpole")
    parser.add_argument("--num-envs", type=int, default=64)
    parser.add_argument("--iterations", type=int, default=2000)
    parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--latent-dim", type=int, default=8)
    parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
    parser.add_argument("--history-length", type=int, default=10)
    parser.add_argument("--output", default="checkpoints/adaptation.pt")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    device = torch.device(args.device)
    hidden_sizes = tuple(args.hidden_sizes)

    # ── Build env + runner (deploy mode with history + DR) ───────
    env_cfg = OmegaConf.create({"env": {
        "robot_path": args.robot_path,
    }})
    env = build_env(args.env, env_cfg)
    runner_cfg = MuJoCoRunnerConfig(
        num_envs=args.num_envs,
        device=args.device,
        history_length=args.history_length,
        rma_mode="deploy",
        domain_rand={
            "qpos_noise_std": 0.01,
            "qvel_noise_std": 0.5,
            "action_delay_steps": [0, 2],
            "friction_scale": [0.6, 1.6],
            "damping_scale": [0.6, 1.6],
            "torque_scale": [0.85, 1.15],
        },
    )
    runner = MuJoCoRunner(env=env, config=runner_cfg)

    # From here on the runner holds resources; release them on any exit path.
    try:
        raw_obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        mu_dim = runner.privileged_dim
        log.info(
            "adaptation_setup",
            raw_obs_dim=raw_obs_dim,
            act_dim=act_dim,
            mu_dim=mu_dim,
            latent_dim=args.latent_dim,
            history_length=args.history_length,
        )

        # ── Load teacher & build deploy model ────────────────────
        # Teacher obs space: [raw_obs, μ]
        teacher_obs_space = spaces.Box(
            low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
        )
        teacher = _load_teacher_checkpoint(
            path=args.checkpoint,
            obs_space=teacher_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            mu_dim=mu_dim,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)

        # Deploy obs space: [raw_obs, history_flat]
        step_dim = raw_obs_dim + act_dim
        deploy_obs_space = spaces.Box(
            low=-torch.inf, high=torch.inf,
            shape=(raw_obs_dim + args.history_length * step_dim,),
        )
        deploy_model = _build_deploy_model(
            teacher=teacher,
            obs_space=deploy_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            history_length=args.history_length,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )

        # Freeze everything except the adaptation module.
        for name, param in deploy_model.named_parameters():
            if "adaptation_module" not in name:
                param.requires_grad_(False)

        optimizer = torch.optim.Adam(
            deploy_model.adaptation_module.parameters(), lr=args.lr,
        )

        # ── Training loop ────────────────────────────────────────
        log.info("starting_adaptation_training", iterations=args.iterations)
        warned_missing_mu = False
        for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
            batch = _collect_adaptation_batch(
                runner=runner,
                deploy_model=deploy_model,
                teacher=teacher,
                rollout_steps=args.rollout_steps,
                history_length=args.history_length,
                raw_obs_dim=raw_obs_dim,
                step_dim=step_dim,
            )
            if batch is None:
                # Nothing to supervise on; skip as before, but say so once
                # instead of silently training on nothing.
                if not warned_missing_mu:
                    log.warning("adaptation_no_privileged_obs", iteration=iteration)
                    warned_missing_mu = True
                continue

            z_pred_all, z_target_all = batch
            loss = torch.nn.functional.mse_loss(z_pred_all, z_target_all)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iteration % 50 == 0:
                log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")

        # ── Save adaptation weights ──────────────────────────────
        out_path = pathlib.Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Save the full deploy model state dict.
        torch.save(deploy_model.state_dict(), out_path)
        log.info("adaptation_saved", path=str(out_path))
    finally:
        runner.close()
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@
Usage (simulation):
mjpython scripts/viz.py env=rotary_cartpole
mjpython scripts/viz.py env=cartpole +com=true
mjpython scripts/viz.py env=rotary_cartpole +com=true
Usage (real hardware — digital twin):
mjpython scripts/viz.py env=rotary_cartpole runner=serial
@@ -104,7 +104,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
@hydra.main(version_base=None, config_path="../configs", config_name="config")
def main(cfg: DictConfig) -> None:
choices = HydraConfig.get().runtime.choices
env_name = choices.get("env", "cartpole")
env_name = choices.get("env", "rotary_cartpole")
runner_name = choices.get("runner", "mujoco")
if runner_name == "serial":
@@ -229,11 +229,12 @@ def _main_serial(cfg: DictConfig, env_name: str) -> None:
_reset_flag[0] = False
serial_runner._send("M0")
serial_runner._drive_to_center()
serial_runner._wait_for_pendulum_still()
serial_runner._wait_for_settle()
logger.info("reset (drive-to-center + settle)")
# Send motor command to real hardware.
motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
# Send motor command to real hardware (same PWM scaling as
# the policy path: ctrl_range-limited).
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
serial_runner._send(f"M{motor_speed}")
# Sync MuJoCo model with real sensor data.