♻️ full agent refactor
This commit is contained in:
@@ -75,14 +75,13 @@ def _infer_hidden_sizes(state_dict: dict[str, torch.Tensor]) -> tuple[int, ...]:
|
||||
|
||||
|
||||
def _infer_encoder_out_dim(state_dict: dict[str, torch.Tensor]) -> int | None:
|
||||
"""Return the history/adaptation encoder output dim, if present.
|
||||
"""Return the history encoder output dim, if present.
|
||||
|
||||
Lets eval reconstruct an embedding policy without knowing the training
|
||||
embedding_dim/latent_dim — read it straight from the saved weights.
|
||||
embedding_dim — read it straight from the saved weights.
|
||||
"""
|
||||
for key in ("history_encoder.fc.weight", "adaptation_module.fc.weight"):
|
||||
if key in state_dict:
|
||||
return state_dict[key].shape[0]
|
||||
if "history_encoder.fc.weight" in state_dict:
|
||||
return state_dict["history_encoder.fc.weight"].shape[0]
|
||||
return None
|
||||
|
||||
|
||||
@@ -92,14 +91,13 @@ def load_policy(
|
||||
action_space: spaces.Space,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
history_length: int = 0,
|
||||
rma_mode: str = "none",
|
||||
raw_obs_dim: int = 0,
|
||||
) -> tuple[SharedMLP, RunningStandardScaler]:
|
||||
"""Load a trained SharedMLP + observation normalizer from a checkpoint.
|
||||
|
||||
For DR + history-embedding policies (history_length > 0) or RMA deploy
|
||||
policies (rma_mode="deploy"), the history/adaptation encoder must be
|
||||
reconstructed too — its output dim is read back from the saved weights.
|
||||
For DR + history-embedding policies (history_length > 0), the history
|
||||
encoder is reconstructed too — its output dim is read back from the
|
||||
saved weights.
|
||||
|
||||
Returns:
|
||||
(model, state_preprocessor) ready for inference.
|
||||
@@ -117,11 +115,9 @@ def load_policy(
|
||||
action_space=action_space,
|
||||
device=device,
|
||||
hidden_sizes=hidden_sizes,
|
||||
history_length=history_length,
|
||||
rma_mode=rma_mode,
|
||||
history_length=history_length if enc_out else 0,
|
||||
raw_obs_dim=raw_obs_dim,
|
||||
embedding_dim=enc_out or 32, # legacy "none" + history
|
||||
latent_dim=enc_out or 8, # RMA deploy adaptation module
|
||||
embedding_dim=enc_out or 32,
|
||||
)
|
||||
model.load_state_dict(ckpt["policy"])
|
||||
model.eval()
|
||||
@@ -189,7 +185,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco_single")
|
||||
|
||||
checkpoint_path = cfg.get("checkpoint", None)
|
||||
@@ -222,7 +218,6 @@ def _eval_sim(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, runner.observation_space, runner.action_space, device,
|
||||
history_length=runner.config.history_length,
|
||||
rma_mode=runner.config.rma_mode,
|
||||
raw_obs_dim=runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
@@ -311,7 +306,6 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
model, preprocessor = load_policy(
|
||||
checkpoint_path, serial_runner.observation_space, serial_runner.action_space, device,
|
||||
history_length=serial_runner.config.history_length,
|
||||
rma_mode=serial_runner.config.rma_mode,
|
||||
raw_obs_dim=serial_runner.env.observation_space.shape[0],
|
||||
)
|
||||
|
||||
@@ -339,9 +333,7 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
if _reset_flag[0]:
|
||||
_reset_flag[0] = False
|
||||
serial_runner._send("M0")
|
||||
serial_runner._drive_to_center()
|
||||
serial_runner._wait_for_pendulum_still()
|
||||
obs, _ = serial_runner.reset()
|
||||
obs, _ = serial_runner.reset() # drives to center + settles
|
||||
step = 0
|
||||
episode += 1
|
||||
episode_reward = 0.0
|
||||
@@ -376,8 +368,8 @@ def _eval_serial(cfg: DictConfig, env_name: str, checkpoint_path: str) -> None:
|
||||
"step", n=step, reward=round(reward.item(), 3),
|
||||
action=round(action[0, 0].item(), 2),
|
||||
ep_reward=round(episode_reward, 1),
|
||||
motor_enc=state["encoder_count"],
|
||||
pend_deg=round(state["pendulum_angle"], 1),
|
||||
motor_deg=round(math.degrees(state["motor_rad"]), 1),
|
||||
pend_deg=round(math.degrees(state["pend_rad"]), 1),
|
||||
)
|
||||
|
||||
# Check for safety / disconnection.
|
||||
|
||||
@@ -352,7 +352,7 @@ def main() -> None:
|
||||
reuse_last_task_id=False,
|
||||
)
|
||||
task.set_base_docker(
|
||||
docker_image="registry.kube.optimize/worker-image:latest",
|
||||
docker_image="git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||
docker_arguments=[
|
||||
"-e", "CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1",
|
||||
"-e", "CLEARML_AGENT_SKIP_PIP_VENV_INSTALL=1",
|
||||
|
||||
@@ -63,7 +63,7 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
||||
"""Initialize ClearML task with project structure and tags."""
|
||||
Task.ignore_requirements("torch")
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco")
|
||||
training_name = choices.get("training", "ppo")
|
||||
|
||||
@@ -113,7 +113,7 @@ def main(cfg: DictConfig) -> None:
|
||||
_valid_keys = {f.name for f in _dc.fields(TrainerConfig)}
|
||||
training_dict = {k: v for k, v in training_dict.items() if k in _valid_keys}
|
||||
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
env = build_env(env_name, cfg)
|
||||
runner = _build_runner(choices.get("runner", "mujoco"), env, cfg)
|
||||
trainer_config = TrainerConfig(**training_dict)
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
"""RMA Phase 2: Train the adaptation module φ(history) → ẑ.
|
||||
|
||||
Loads a Phase 1 (teacher) checkpoint, freezes the backbone + env_encoder,
|
||||
and trains a HistoryEncoder (adaptation module) to predict the teacher's
|
||||
latent z from observation-action history using supervised MSE.
|
||||
|
||||
Usage:
|
||||
python scripts/train_adaptation.py \
|
||||
--checkpoint runs/<run>/checkpoints/agent_XXXXX.pt \
|
||||
--env rotary_cartpole \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--num-envs 64 \
|
||||
--iterations 2000 \
|
||||
--lr 3e-4
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
_PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
import structlog
|
||||
import torch
|
||||
import tqdm
|
||||
from gymnasium import spaces
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from src.core.registry import build_env
|
||||
from src.models.mlp import SharedMLP, EnvironmentEncoder, HistoryEncoder
|
||||
from src.runners.mujoco import MuJoCoRunner, MuJoCoRunnerConfig
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def _load_teacher_checkpoint(
    path: str,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    mu_dim: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Build a teacher-mode SharedMLP and populate it from a checkpoint file.

    Args:
        path: Location of the Phase 1 checkpoint on disk.
        obs_space: Observation space handed to the SharedMLP constructor.
        act_space: Action space handed to the SharedMLP constructor.
        device: Device the network is created on and the checkpoint is
            mapped to.
        raw_obs_dim: Forwarded to SharedMLP as ``raw_obs_dim``.
        mu_dim: Forwarded to SharedMLP as ``mu_dim``.
        hidden_sizes: Forwarded to SharedMLP as ``hidden_sizes``.
        latent_dim: Forwarded to SharedMLP as ``latent_dim``.

    Returns:
        The SharedMLP (``rma_mode="teacher"``) with the saved weights loaded.
    """
    teacher = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="teacher",
        raw_obs_dim=raw_obs_dim,
        mu_dim=mu_dim,
        latent_dim=latent_dim,
    )

    checkpoint = torch.load(path, map_location=device, weights_only=True)

    # Checkpoints written by skrl nest the state dict under "policy";
    # anything else is treated as a bare state dict.
    weights = checkpoint["policy"] if "policy" in checkpoint else checkpoint
    teacher.load_state_dict(weights)
    return teacher
|
||||
|
||||
|
||||
def _build_deploy_model(
    teacher: SharedMLP,
    obs_space: spaces.Space,
    act_space: spaces.Space,
    device: torch.device,
    raw_obs_dim: int,
    history_length: int,
    hidden_sizes: tuple[int, ...],
    latent_dim: int,
) -> SharedMLP:
    """Instantiate a deploy-mode SharedMLP seeded with the teacher's weights.

    The backbone (``net``), the policy head (``mean_layer``), the value head
    (``value_layer``) and ``log_std_parameter`` are copied from ``teacher``.
    Nothing else on the new model is touched here.

    Args:
        teacher: Source model whose shared components are copied.
        obs_space: Observation space for the deploy model.
        act_space: Action space for the deploy model.
        device: Device the deploy model is created on.
        raw_obs_dim: Forwarded to SharedMLP as ``raw_obs_dim``.
        history_length: Forwarded to SharedMLP as ``history_length``.
        hidden_sizes: Forwarded to SharedMLP as ``hidden_sizes``.
        latent_dim: Forwarded to SharedMLP as ``latent_dim``.

    Returns:
        The new SharedMLP constructed with ``rma_mode="deploy"``.
    """
    student = SharedMLP(
        observation_space=obs_space,
        action_space=act_space,
        device=device,
        hidden_sizes=hidden_sizes,
        rma_mode="deploy",
        raw_obs_dim=raw_obs_dim,
        history_length=history_length,
        latent_dim=latent_dim,
    )

    # Sub-modules shared verbatim between teacher and deploy model.
    for shared_name in ("net", "mean_layer", "value_layer"):
        source_module = getattr(teacher, shared_name)
        target_module = getattr(student, shared_name)
        target_module.load_state_dict(source_module.state_dict())

    # log_std is a bare parameter rather than a module, so copy it in place.
    student.log_std_parameter.data.copy_(teacher.log_std_parameter.data)

    return student
|
||||
|
||||
|
||||
def main() -> None:
    """Train the adaptation module against a frozen Phase 1 teacher.

    Parses the command line, builds a deploy-mode MuJoCo runner with domain
    randomization, loads the teacher checkpoint, creates a deploy model that
    shares the teacher's backbone and heads, and then fits only
    ``deploy_model.adaptation_module`` with an MSE loss between its output
    and ``teacher.env_encoder(mu)``, where ``mu`` is the ``privileged_obs``
    entry of the runner's step info. The full deploy-model state dict is
    written to ``--output``.

    Each iteration resets the runner and collects exactly one rollout; the
    adaptation-module forward pass runs with gradients enabled while acting
    and teacher encoding stay under ``torch.no_grad()``. The runner is closed
    even if training or saving fails.
    """
    parser = argparse.ArgumentParser(description="RMA Phase 2: train adaptation module")
    parser.add_argument("--checkpoint", required=True, help="Path to Phase 1 teacher checkpoint")
    parser.add_argument("--env", default="rotary_cartpole")
    parser.add_argument("--robot-path", default="assets/rotary_cartpole")
    parser.add_argument("--num-envs", type=int, default=64)
    parser.add_argument("--iterations", type=int, default=2000)
    parser.add_argument("--rollout-steps", type=int, default=256, help="Steps per rollout")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--latent-dim", type=int, default=8)
    parser.add_argument("--hidden-sizes", type=int, nargs="+", default=[128, 128])
    parser.add_argument("--history-length", type=int, default=10)
    parser.add_argument("--output", default="checkpoints/adaptation.pt")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    device = torch.device(args.device)
    hidden_sizes = tuple(args.hidden_sizes)

    # ── Build env + runner (deploy mode with history + DR) ───────
    env_cfg = OmegaConf.create({"env": {
        "robot_path": args.robot_path,
    }})
    env = build_env(args.env, env_cfg)

    runner_cfg = MuJoCoRunnerConfig(
        num_envs=args.num_envs,
        device=args.device,
        history_length=args.history_length,
        rma_mode="deploy",
        domain_rand={
            "qpos_noise_std": 0.01,
            "qvel_noise_std": 0.5,
            "action_delay_steps": [0, 2],
            "friction_scale": [0.6, 1.6],
            "damping_scale": [0.6, 1.6],
            "torque_scale": [0.85, 1.15],
        },
    )
    runner = MuJoCoRunner(env=env, config=runner_cfg)

    # Everything after runner creation sits in try/finally so the runner is
    # released on any failure, not only on the success path.
    try:
        raw_obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        mu_dim = runner.privileged_dim

        log.info(
            "adaptation_setup",
            raw_obs_dim=raw_obs_dim,
            act_dim=act_dim,
            mu_dim=mu_dim,
            latent_dim=args.latent_dim,
            history_length=args.history_length,
        )

        # ── Load teacher & build deploy model ────────────────────────
        # Teacher obs space: [raw_obs, μ]
        teacher_obs_space = spaces.Box(
            low=-torch.inf, high=torch.inf, shape=(raw_obs_dim + mu_dim,),
        )

        teacher = _load_teacher_checkpoint(
            path=args.checkpoint,
            obs_space=teacher_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            mu_dim=mu_dim,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad_(False)

        # Deploy obs space: [raw_obs, history_flat]
        step_dim = raw_obs_dim + act_dim
        deploy_obs_space = spaces.Box(
            low=-torch.inf, high=torch.inf,
            shape=(raw_obs_dim + args.history_length * step_dim,),
        )

        deploy_model = _build_deploy_model(
            teacher=teacher,
            obs_space=deploy_obs_space,
            act_space=env.action_space,
            device=device,
            raw_obs_dim=raw_obs_dim,
            history_length=args.history_length,
            hidden_sizes=hidden_sizes,
            latent_dim=args.latent_dim,
        )

        # Freeze everything except the adaptation module.
        for name, param in deploy_model.named_parameters():
            if "adaptation_module" not in name:
                param.requires_grad_(False)

        optimizer = torch.optim.Adam(
            deploy_model.adaptation_module.parameters(), lr=args.lr,
        )

        # ── Training loop ────────────────────────────────────────────
        log.info("starting_adaptation_training", iterations=args.iterations)

        for iteration in tqdm.tqdm(range(args.iterations), desc="Adaptation"):
            # Training rollouts always start from a fresh reset. A single
            # rollout is collected per iteration; the previous extra
            # no-grad rollout was discarded before the update and only
            # doubled the simulation cost.
            obs, _ = runner.reset()

            z_targets: list[torch.Tensor] = []
            z_preds: list[torch.Tensor] = []

            for _step in range(args.rollout_steps):
                aug_obs = obs  # already augmented by runner

                # Acting and stepping never need gradients.
                with torch.no_grad():
                    actions = deploy_model.act(
                        {"states": aug_obs}, role="policy",
                    )[0]
                    obs, _, _, _, info = runner.step(actions)

                mu = info.get("privileged_obs")
                if mu is None:
                    # No supervision target for this step.
                    continue

                # Teacher latent is a fixed regression target.
                with torch.no_grad():
                    z_targets.append(teacher.env_encoder(mu))

                # The adaptation module is the only part run WITH
                # gradients. Its input is the history slice of the
                # observation that produced `actions`.
                hist_flat = aug_obs[:, raw_obs_dim:]
                history = hist_flat.reshape(
                    -1, args.history_length, step_dim,
                )
                z_preds.append(deploy_model.adaptation_module(history))

            if not z_targets:
                # Nothing to fit this iteration (runner exposed no
                # privileged observations).
                continue

            z_target_all = torch.cat(z_targets, dim=0).detach()
            z_pred_all = torch.cat(z_preds, dim=0)

            loss = torch.nn.functional.mse_loss(z_pred_all, z_target_all)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iteration % 50 == 0:
                log.info("adaptation_loss", iteration=iteration, mse=f"{loss.item():.6f}")

        # ── Save adaptation weights ──────────────────────────────────
        out_path = pathlib.Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the full deploy model state dict.
        torch.save(deploy_model.state_dict(), out_path)
        log.info("adaptation_saved", path=str(out_path))
    finally:
        runner.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Usage (simulation):
|
||||
mjpython scripts/viz.py env=rotary_cartpole
|
||||
mjpython scripts/viz.py env=cartpole +com=true
|
||||
mjpython scripts/viz.py env=rotary_cartpole +com=true
|
||||
|
||||
Usage (real hardware — digital twin):
|
||||
mjpython scripts/viz.py env=rotary_cartpole runner=serial
|
||||
@@ -104,7 +104,7 @@ def _add_action_arrow(viewer, model, data, action_val: float) -> None:
|
||||
@hydra.main(version_base=None, config_path="../configs", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
choices = HydraConfig.get().runtime.choices
|
||||
env_name = choices.get("env", "cartpole")
|
||||
env_name = choices.get("env", "rotary_cartpole")
|
||||
runner_name = choices.get("runner", "mujoco")
|
||||
|
||||
if runner_name == "serial":
|
||||
@@ -229,11 +229,12 @@ def _main_serial(cfg: DictConfig, env_name: str) -> None:
|
||||
_reset_flag[0] = False
|
||||
serial_runner._send("M0")
|
||||
serial_runner._drive_to_center()
|
||||
serial_runner._wait_for_pendulum_still()
|
||||
serial_runner._wait_for_settle()
|
||||
logger.info("reset (drive-to-center + settle)")
|
||||
|
||||
# Send motor command to real hardware.
|
||||
motor_speed = int(np.clip(action_val, -1.0, 1.0) * 255)
|
||||
# Send motor command to real hardware (same PWM scaling as
|
||||
# the policy path: ctrl_range-limited).
|
||||
motor_speed = int(np.clip(action_val, -1.0, 1.0) * serial_runner._max_pwm)
|
||||
serial_runner._send(f"M{motor_speed}")
|
||||
|
||||
# Sync MuJoCo model with real sensor data.
|
||||
|
||||
Reference in New Issue
Block a user