feat: RMA-style history-conditioned policy for sim2real adaptation

Added a temporal observation history buffer and 1D-CNN encoder so the
policy can implicitly infer environment parameters (mass, friction,
gear ratios, etc.) from recent (obs, action) dynamics.

Architecture:
  history window [(obs₀,a₀), ..., (obs_{H-1},a_{H-1})]
      → 1D-CNN HistoryEncoder → embedding (32-dim)
      → concat [current_obs, embedding] → MLP → action

Components:
- BaseRunner: history ring buffer, _push_history/_reset_history,
  augmented obs space (6 + H×7 = 76 with H=10)
- HistoryEncoder (src/models/mlp.py): 2-layer temporal Conv1d + GAP
- SharedMLP: optional history_length/raw_obs_dim/embedding_dim params;
  splits augmented obs, encodes history, feeds [obs, emb] to MLP
- TrainerConfig: history_length, embedding_dim fields
- All runner YAML configs: history_length=10 (the BaseRunnerConfig
  dataclass default stays 0 = history disabled)
- Tests: encoder shape, model with/without history, config defaults
This commit is contained in:
2026-03-28 18:58:24 +01:00
parent 8ed9afe583
commit 8cc84d6a21
9 changed files with 209 additions and 9 deletions

View File

@@ -2,3 +2,4 @@ num_envs: 1024 # MJX shines with many parallel envs
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 10 substeps: 10
history_length: 10 # RMA-style: 10-step window of (obs, action) pairs

View File

@@ -2,6 +2,7 @@ num_envs: 64
device: auto # auto = cuda if available, else cpu device: auto # auto = cuda if available, else cpu
dt: 0.002 dt: 0.002
substeps: 10 substeps: 10
history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
# ── Sim2real: domain randomization ─────────────────────────────── # ── Sim2real: domain randomization ───────────────────────────────
domain_rand: domain_rand:

View File

@@ -5,3 +5,4 @@ num_envs: 1
device: cpu device: cpu
dt: 0.002 dt: 0.002
substeps: 10 substeps: 10
history_length: 10

View File

@@ -8,3 +8,4 @@ port: /dev/cu.usbserial-0001
baud: 115200 baud: 115200
dt: 0.02 # control loop period (50 Hz, matches training) dt: 0.02 # control loop period (50 Hz, matches training)
no_data_timeout: 2.0 # seconds of silence before declaring disconnect no_data_timeout: 2.0 # seconds of silence before declaring disconnect
history_length: 10 # must match training runner

View File

@@ -18,6 +18,10 @@ max_log_std: 2.0
record_video_every: 10000 record_video_every: 10000
# RMA-style history encoder
history_length: 10 # temporal window (must match runner)
embedding_dim: 32 # history encoder output dimension
# ClearML remote execution (GPU worker) # ClearML remote execution (GPU worker)
remote: false remote: false

View File

@@ -14,6 +14,7 @@ T = TypeVar("T")
class BaseRunnerConfig: class BaseRunnerConfig:
num_envs: int = 1 num_envs: int = 1
device: str = "cpu" device: str = "cpu"
history_length: int = 0 # 0 = no history (single obs), >0 = RMA-style
class BaseRunner(abc.ABC, Generic[T]): class BaseRunner(abc.ABC, Generic[T]):
def __init__(self, env: BaseEnv, config: T) -> None: def __init__(self, env: BaseEnv, config: T) -> None:
@@ -36,6 +37,26 @@ class BaseRunner(abc.ABC, Generic[T]):
self.config.num_envs, dtype=torch.long, device=self.config.device self.config.num_envs, dtype=torch.long, device=self.config.device
) )
# ── History buffer (RMA-style adaptation) ────────────────
self._history_len: int = getattr(self.config, "history_length", 0)
if self._history_len > 0:
obs_dim = self.observation_space.shape[0]
act_dim = self.action_space.shape[0]
self._history_obs_dim = obs_dim
self._history_act_dim = act_dim
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
# Ring buffer: (num_envs, history_length, obs_dim + act_dim)
self._history_buf = torch.zeros(
self.config.num_envs, self._history_len, self._history_step_dim,
device=self.config.device,
)
# Augmented observation space: [current_obs, history_flat]
from gymnasium import spaces
aug_dim = obs_dim + self._history_len * self._history_step_dim
self.observation_space = spaces.Box(
low=-torch.inf, high=torch.inf, shape=(aug_dim,)
)
@property @property
@abc.abstractmethod @abc.abstractmethod
def num_envs(self) -> int: def num_envs(self) -> int:
@@ -63,14 +84,44 @@ class BaseRunner(abc.ABC, Generic[T]):
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
self._offscreen_renderer.close() self._offscreen_renderer.close()
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
"""Concatenate history buffer to current obs if history is enabled."""
if self._history_len <= 0:
return obs
# Flatten history: (num_envs, H, step_dim) → (num_envs, H * step_dim)
hist_flat = self._history_buf.reshape(obs.shape[0], -1)
return torch.cat([obs, hist_flat], dim=-1)
def _push_history(self, obs: torch.Tensor, actions: torch.Tensor,
env_ids: torch.Tensor | None = None) -> None:
"""Push (obs, action) into the ring buffer (shift left, append right)."""
if self._history_len <= 0:
return
step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1)
if env_ids is None:
# All envs.
self._history_buf = torch.roll(self._history_buf, -1, dims=1)
self._history_buf[:, -1] = step
else:
self._history_buf[env_ids] = torch.roll(
self._history_buf[env_ids], -1, dims=1
)
self._history_buf[env_ids, -1] = step[env_ids]
def _reset_history(self, env_ids: torch.Tensor) -> None:
"""Zero the history buffer for reset envs."""
if self._history_len > 0:
self._history_buf[env_ids] = 0.0
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]: def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
all_ids = torch.arange(self.num_envs, device=self.device) all_ids = torch.arange(self.num_envs, device=self.device)
qpos, qvel = self._sim_reset(all_ids) qpos, qvel = self._sim_reset(all_ids)
self.step_counts.zero_() self.step_counts.zero_()
self._reset_history(all_ids)
state = self.env.build_state(qpos, qvel) state = self.env.build_state(qpos, qvel)
obs = self.env.compute_observations(state) obs = self.env.compute_observations(state)
return obs, {} return self._augment_obs(obs), {}
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]: def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
self._last_actions = actions self._last_actions = actions
@@ -83,23 +134,27 @@ class BaseRunner(abc.ABC, Generic[T]):
terminated = self.env.compute_terminations(state) terminated = self.env.compute_terminations(state)
truncated = self.env.compute_truncations(self.step_counts) truncated = self.env.compute_truncations(self.step_counts)
# Push current (obs, action) into history before augmenting.
self._push_history(obs, actions)
info: dict[str, Any] = {} info: dict[str, Any] = {}
done = terminated | truncated done = terminated | truncated
done_ids = done.nonzero(as_tuple=False).squeeze(-1) done_ids = done.nonzero(as_tuple=False).squeeze(-1)
if done_ids.numel() > 0: if done_ids.numel() > 0:
info["final_observations"] = obs[done_ids].clone() info["final_observations"] = self._augment_obs(obs)[done_ids].clone()
info["final_env_ids"] = done_ids.clone() info["final_env_ids"] = done_ids.clone()
reset_qpos, reset_qvel = self._sim_reset(done_ids) reset_qpos, reset_qvel = self._sim_reset(done_ids)
self.step_counts[done_ids] = 0 self.step_counts[done_ids] = 0
self._reset_history(done_ids)
reset_state = self.env.build_state(reset_qpos, reset_qvel) reset_state = self.env.build_state(reset_qpos, reset_qvel)
obs[done_ids] = self.env.compute_observations(reset_state) obs[done_ids] = self.env.compute_observations(reset_state)
# skrl expects (num_envs, 1) for rewards/terminated/truncated # skrl expects (num_envs, 1) for rewards/terminated/truncated
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
def _render_frame(self, env_idx: int = 0) -> np.ndarray: def _render_frame(self, env_idx: int = 0) -> np.ndarray:
"""Return a raw RGB frame. Override in subclass.""" """Return a raw RGB frame. Override in subclass."""

View File

@@ -3,14 +3,72 @@ import torch.nn as nn
from gymnasium import spaces from gymnasium import spaces
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
class HistoryEncoder(nn.Module):
    """Temporal 1D-CNN that compresses an (obs, action) window to a vector.

    Input:  (batch, history_length, step_dim)
    Output: (batch, embedding_dim)

    Pipeline: two Conv1d + ELU stages over the time axis, a mean over
    time (global average pool), then a linear projection.  The embedding
    is meant to let the policy implicitly infer environment parameters
    (mass, friction, gear, etc.) from recent dynamics, RMA-style.
    """

    def __init__(
        self,
        history_length: int,
        step_dim: int,
        embedding_dim: int = 32,
        hidden_channels: int = 32,
    ) -> None:
        """Build the convolutional stack and the output projection.

        ``history_length`` is accepted for interface symmetry but is not
        referenced by any layer: kernel_size=3 with padding=1 keeps the
        sequence length unchanged, and the pooling averages whatever
        length arrives.
        """
        super().__init__()
        same_length = {"kernel_size": 3, "padding": 1}
        # Operates on (batch, step_dim, history_length), i.e. after the
        # transpose performed in forward().
        temporal_stack = [
            nn.Conv1d(step_dim, hidden_channels, **same_length),
            nn.ELU(),
            nn.Conv1d(hidden_channels, hidden_channels, **same_length),
            nn.ELU(),
        ]
        self.conv = nn.Sequential(*temporal_stack)
        self.fc = nn.Linear(hidden_channels, embedding_dim)

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        """Encode ``history`` of shape (batch, history_length, step_dim)."""
        # Conv1d wants channels before the sequence axis.
        channels_first = torch.transpose(history, 1, 2)
        features = self.conv(channels_first)
        # Global average pool over the time axis.
        pooled = torch.mean(features, dim=-1)
        return self.fc(pooled)
class SharedMLP(GaussianMixin, DeterministicMixin, Model): class SharedMLP(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0): def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0, history_length: int = 0, raw_obs_dim: int = 0, embedding_dim: int = 32):
Model.__init__(self, observation_space, action_space, device) Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
DeterministicMixin.__init__(self, clip_actions) DeterministicMixin.__init__(self, clip_actions)
layers = [] self._history_length = history_length
in_dim: int = self.num_observations self._raw_obs_dim = raw_obs_dim
self._embedding_dim = embedding_dim
if history_length > 0 and raw_obs_dim > 0:
# The observation is [current_obs(raw_obs_dim), history_flat(H * step_dim)].
act_dim = self.num_actions
step_dim = raw_obs_dim + act_dim
self.history_encoder = HistoryEncoder(
history_length=history_length,
step_dim=step_dim,
embedding_dim=embedding_dim,
)
# MLP input = raw obs + history embedding.
in_dim = raw_obs_dim + embedding_dim
else:
self.history_encoder = None
in_dim = self.num_observations
layers: list[nn.Module] = []
for hidden_size in hidden_sizes: for hidden_size in hidden_sizes:
layers.append(nn.Linear(in_dim, hidden_size)) layers.append(nn.Linear(in_dim, hidden_size))
layers.append(nn.ELU()) layers.append(nn.ELU())
@@ -32,17 +90,29 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
elif role == "value": elif role == "value":
return DeterministicMixin.act(self, inputs, role) return DeterministicMixin.act(self, inputs, role)
def _encode(self, states: torch.Tensor) -> torch.Tensor:
"""Split augmented obs into current obs + history, encode, concat."""
if self.history_encoder is None:
return self.net(states)
obs = states[:, :self._raw_obs_dim]
hist_flat = states[:, self._raw_obs_dim:]
step_dim = self._raw_obs_dim + self.num_actions
history = hist_flat.reshape(-1, self._history_length, step_dim)
embedding = self.history_encoder(history)
return self.net(torch.cat([obs, embedding], dim=-1))
def compute( def compute(
self, inputs: dict[str, torch.Tensor], role: str = "" self, inputs: dict[str, torch.Tensor], role: str = ""
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
if role == "policy": if role == "policy":
self._shared_output = self.net(inputs["states"]) self._shared_output = self._encode(inputs["states"])
return self.mean_layer(self._shared_output), self.log_std_parameter, {} return self.mean_layer(self._shared_output), self.log_std_parameter, {}
elif role == "value": elif role == "value":
shared_output = ( shared_output = (
self._shared_output self._shared_output
if self._shared_output is not None if self._shared_output is not None
else self.net(inputs["states"]) else self._encode(inputs["states"])
) )
self._shared_output = None self._shared_output = None
return self.value_layer(shared_output), {} return self.value_layer(shared_output), {}

View File

@@ -48,6 +48,10 @@ class TrainerConfig:
record_video_every: int = 10_000 # 0 = disabled record_video_every: int = 10_000 # 0 = disabled
record_video_fps: int = 0 # 0 = derive from sim dt×substeps record_video_fps: int = 0 # 0 = derive from sim dt×substeps
# History encoder (RMA-style adaptation)
history_length: int = 0 # 0 = disabled, >0 = temporal window size
embedding_dim: int = 32 # history encoder output dimension
# ── Video-recording trainer ────────────────────────────────────────── # ── Video-recording trainer ──────────────────────────────────────────
@@ -173,6 +177,11 @@ class Trainer:
device=device, device=device,
) )
# Determine raw obs dim (without history augmentation).
raw_obs_dim = 0
if self.config.history_length > 0:
raw_obs_dim = self.runner.env.observation_space.shape[0]
self.model = SharedMLP( self.model = SharedMLP(
observation_space=obs_space, observation_space=obs_space,
action_space=act_space, action_space=act_space,
@@ -181,6 +190,9 @@ class Trainer:
initial_log_std=self.config.initial_log_std, initial_log_std=self.config.initial_log_std,
min_log_std=self.config.min_log_std, min_log_std=self.config.min_log_std,
max_log_std=self.config.max_log_std, max_log_std=self.config.max_log_std,
history_length=self.config.history_length,
raw_obs_dim=raw_obs_dim,
embedding_dim=self.config.embedding_dim,
) )
models = {"policy": self.model, "value": self.model} models = {"policy": self.model, "value": self.model}

View File

@@ -1,11 +1,14 @@
"""Unit tests for MuJoCoRunner domain randomization.""" """Unit tests for MuJoCoRunner domain randomization and history buffer."""
import dataclasses import dataclasses
import numpy as np import numpy as np
import pytest import pytest
import torch
from gymnasium import spaces
from src.runners.mujoco import DomainRandConfig, MuJoCoRunnerConfig from src.runners.mujoco import DomainRandConfig, MuJoCoRunnerConfig
from src.models.mlp import SharedMLP, HistoryEncoder
class TestDomainRandConfig: class TestDomainRandConfig:
@@ -37,3 +40,55 @@ class TestMuJoCoRunnerConfig:
assert isinstance(cfg.domain_rand, DomainRandConfig) assert isinstance(cfg.domain_rand, DomainRandConfig)
assert cfg.domain_rand.mass_frac == 0.2 assert cfg.domain_rand.mass_frac == 0.2
assert cfg.domain_rand.friction_frac == 0.3 assert cfg.domain_rand.friction_frac == 0.3
def test_history_length_default(self) -> None:
    """A default-constructed runner config has history_length == 0."""
    assert MuJoCoRunnerConfig().history_length == 0
class TestHistoryEncoder:
    """Output-shape checks for HistoryEncoder."""

    def test_output_shape(self) -> None:
        """A (4, 10, 7) window is encoded to a (4, 32) embedding."""
        batch, window, step_dim, emb_dim = 4, 10, 7, 32
        encoder = HistoryEncoder(
            history_length=window, step_dim=step_dim, embedding_dim=emb_dim
        )
        window_batch = torch.randn(batch, window, step_dim)
        embedding = encoder(window_batch)
        assert embedding.shape == (batch, emb_dim)

    def test_different_embedding_dim(self) -> None:
        """The embedding width follows the embedding_dim argument."""
        batch, window, step_dim, emb_dim = 2, 5, 7, 16
        encoder = HistoryEncoder(
            history_length=window, step_dim=step_dim, embedding_dim=emb_dim
        )
        window_batch = torch.randn(batch, window, step_dim)
        embedding = encoder(window_batch)
        assert embedding.shape == (batch, emb_dim)
class TestSharedMLPWithHistory:
    """SharedMLP forward-shape checks with and without a history encoder."""

    def test_no_history(self) -> None:
        """Without history arguments no encoder is built and obs pass as-is."""
        observation_space = spaces.Box(low=-1.0, high=1.0, shape=(6,))
        action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,))
        net = SharedMLP(
            observation_space,
            action_space,
            torch.device("cpu"),
            hidden_sizes=(32, 32),
        )
        assert net.history_encoder is None

        batch = {"states": torch.randn(4, 6)}
        mean, log_std, _ = net.compute(batch, role="policy")
        assert mean.shape == (4, 1)

    def test_with_history(self) -> None:
        """With history the augmented obs is split and the window encoded."""
        raw_obs_dim, act_dim, H = 6, 1, 10
        step_dim = raw_obs_dim + act_dim  # 7
        aug_dim = raw_obs_dim + H * step_dim  # 6 + 70 = 76

        observation_space = spaces.Box(low=-1.0, high=1.0, shape=(aug_dim,))
        action_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,))
        net = SharedMLP(
            observation_space,
            action_space,
            torch.device("cpu"),
            hidden_sizes=(32, 32),
            history_length=H,
            raw_obs_dim=raw_obs_dim,
            embedding_dim=32,
        )
        assert net.history_encoder is not None

        batch = {"states": torch.randn(4, aug_dim)}
        mean, log_std, _ = net.compute(batch, role="policy")
        assert mean.shape == (4, act_dim)

        value, _ = net.compute(batch, role="value")
        assert value.shape == (4, 1)