feat: RMA-style history-conditioned policy for sim2real adaptation
Added a temporal observation history buffer and 1D-CNN encoder so the
policy can implicitly infer environment parameters (mass, friction,
gear ratios, etc.) from recent (obs, action) dynamics.
Architecture:
history window [(obs₀,a₀), ..., (obs_{H-1},a_{H-1})]
→ 1D-CNN HistoryEncoder → embedding (32-dim)
→ concat [current_obs, embedding] → MLP → action
Components:
- BaseRunner: history ring buffer, _push_history/_reset_history,
augmented obs space (6 + H×7 = 76 with H=10)
- HistoryEncoder (src/models/mlp.py): 2-layer temporal Conv1d + global average pooling over time + linear projection
- SharedMLP: optional history_length/raw_obs_dim/embedding_dim params;
splits augmented obs, encodes history, feeds [obs, emb] to MLP
- TrainerConfig: history_length, embedding_dim fields
- All runner YAML configs: set history_length=10 (the BaseRunnerConfig dataclass default stays 0 = history disabled)
- Tests: encoder shape, model with/without history, config defaults
This commit is contained in:
@@ -2,3 +2,4 @@ num_envs: 1024 # MJX shines with many parallel envs
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
|
||||
|
||||
@@ -2,6 +2,7 @@ num_envs: 64
|
||||
device: auto # auto = cuda if available, else cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10 # RMA-style: 10-step window of (obs, action) pairs
|
||||
|
||||
# ── Sim2real: domain randomization ───────────────────────────────
|
||||
domain_rand:
|
||||
|
||||
@@ -5,3 +5,4 @@ num_envs: 1
|
||||
device: cpu
|
||||
dt: 0.002
|
||||
substeps: 10
|
||||
history_length: 10
|
||||
|
||||
@@ -8,3 +8,4 @@ port: /dev/cu.usbserial-0001
|
||||
baud: 115200
|
||||
dt: 0.02 # control loop period (50 Hz, matches training)
|
||||
no_data_timeout: 2.0 # seconds of silence before declaring disconnect
|
||||
history_length: 10 # must match training runner
|
||||
|
||||
@@ -18,6 +18,10 @@ max_log_std: 2.0
|
||||
|
||||
record_video_every: 10000
|
||||
|
||||
# RMA-style history encoder
|
||||
history_length: 10 # temporal window (must match runner)
|
||||
embedding_dim: 32 # history encoder output dimension
|
||||
|
||||
# ClearML remote execution (GPU worker)
|
||||
remote: false
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ T = TypeVar("T")
|
||||
class BaseRunnerConfig:
|
||||
num_envs: int = 1
|
||||
device: str = "cpu"
|
||||
history_length: int = 0 # 0 = no history (single obs), >0 = RMA-style
|
||||
|
||||
class BaseRunner(abc.ABC, Generic[T]):
|
||||
def __init__(self, env: BaseEnv, config: T) -> None:
|
||||
@@ -36,6 +37,26 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
self.config.num_envs, dtype=torch.long, device=self.config.device
|
||||
)
|
||||
|
||||
# ── History buffer (RMA-style adaptation) ────────────────
|
||||
self._history_len: int = getattr(self.config, "history_length", 0)
|
||||
if self._history_len > 0:
|
||||
obs_dim = self.observation_space.shape[0]
|
||||
act_dim = self.action_space.shape[0]
|
||||
self._history_obs_dim = obs_dim
|
||||
self._history_act_dim = act_dim
|
||||
self._history_step_dim = obs_dim + act_dim # each step stores (obs, action)
|
||||
# Ring buffer: (num_envs, history_length, obs_dim + act_dim)
|
||||
self._history_buf = torch.zeros(
|
||||
self.config.num_envs, self._history_len, self._history_step_dim,
|
||||
device=self.config.device,
|
||||
)
|
||||
# Augmented observation space: [current_obs, history_flat]
|
||||
from gymnasium import spaces
|
||||
aug_dim = obs_dim + self._history_len * self._history_step_dim
|
||||
self.observation_space = spaces.Box(
|
||||
low=-torch.inf, high=torch.inf, shape=(aug_dim,)
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def num_envs(self) -> int:
|
||||
@@ -63,14 +84,44 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
|
||||
self._offscreen_renderer.close()
|
||||
|
||||
def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor:
    """Concatenate the flattened history buffer to the current observation.

    Args:
        obs: Current observations; the first dimension is the env/batch axis.

    Returns:
        ``obs`` itself when history is disabled (``_history_len <= 0``),
        otherwise ``[obs, history_flat]`` concatenated on the last axis.
    """
    if self._history_len <= 0:
        return obs
    # Flatten history: (num_envs, H, step_dim) → (num_envs, H * step_dim)
    hist_flat = self._history_buf.reshape(obs.shape[0], -1)
    return torch.cat([obs, hist_flat], dim=-1)
|
||||
|
||||
def _push_history(self, obs: torch.Tensor, actions: torch.Tensor,
|
||||
env_ids: torch.Tensor | None = None) -> None:
|
||||
"""Push (obs, action) into the ring buffer (shift left, append right)."""
|
||||
if self._history_len <= 0:
|
||||
return
|
||||
step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1)
|
||||
if env_ids is None:
|
||||
# All envs.
|
||||
self._history_buf = torch.roll(self._history_buf, -1, dims=1)
|
||||
self._history_buf[:, -1] = step
|
||||
else:
|
||||
self._history_buf[env_ids] = torch.roll(
|
||||
self._history_buf[env_ids], -1, dims=1
|
||||
)
|
||||
self._history_buf[env_ids, -1] = step[env_ids]
|
||||
|
||||
def _reset_history(self, env_ids: torch.Tensor) -> None:
    """Zero the history buffer rows of the envs in ``env_ids``.

    Does nothing when history is disabled (``_history_len <= 0``).
    """
    if self._history_len > 0:
        self._history_buf[env_ids] = 0.0
|
||||
|
||||
def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
    """Reset every env and return the (history-augmented) initial observation.

    Resets the simulation state for all env ids, zeroes the step counters
    and clears the history buffer before computing observations.

    Returns:
        ``(obs, info)`` where ``obs`` has been passed through
        ``_augment_obs`` and ``info`` is an empty dict.
    """
    all_ids = torch.arange(self.num_envs, device=self.device)
    qpos, qvel = self._sim_reset(all_ids)
    self.step_counts.zero_()
    self._reset_history(all_ids)

    state = self.env.build_state(qpos, qvel)
    obs = self.env.compute_observations(state)
    # The stale pre-change `return obs, {}` is dropped: it made the
    # augmented return below unreachable.
    return self._augment_obs(obs), {}
|
||||
|
||||
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
|
||||
self._last_actions = actions
|
||||
@@ -83,23 +134,27 @@ class BaseRunner(abc.ABC, Generic[T]):
|
||||
terminated = self.env.compute_terminations(state)
|
||||
truncated = self.env.compute_truncations(self.step_counts)
|
||||
|
||||
# Push current (obs, action) into history before augmenting.
|
||||
self._push_history(obs, actions)
|
||||
|
||||
info: dict[str, Any] = {}
|
||||
|
||||
done = terminated | truncated
|
||||
done_ids = done.nonzero(as_tuple=False).squeeze(-1)
|
||||
|
||||
if done_ids.numel() > 0:
|
||||
info["final_observations"] = obs[done_ids].clone()
|
||||
info["final_observations"] = self._augment_obs(obs)[done_ids].clone()
|
||||
info["final_env_ids"] = done_ids.clone()
|
||||
|
||||
reset_qpos, reset_qvel = self._sim_reset(done_ids)
|
||||
self.step_counts[done_ids] = 0
|
||||
self._reset_history(done_ids)
|
||||
|
||||
reset_state = self.env.build_state(reset_qpos, reset_qvel)
|
||||
obs[done_ids] = self.env.compute_observations(reset_state)
|
||||
|
||||
# skrl expects (num_envs, 1) for rewards/terminated/truncated
|
||||
return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info
|
||||
|
||||
def _render_frame(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Return a raw RGB frame. Override in subclass."""
|
||||
|
||||
@@ -3,14 +3,72 @@ import torch.nn as nn
|
||||
from gymnasium import spaces
|
||||
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
|
||||
|
||||
|
||||
class HistoryEncoder(nn.Module):
    """1D-CNN encoder over a temporal window of (obs, action) pairs.

    Input:  (batch, history_length, step_dim)
    Output: (batch, embedding_dim)

    Architecture: two temporal conv layers → global average pool → linear.
    This lets the policy implicitly infer environment parameters (mass,
    friction, gear, etc.) from recent dynamics — the core of RMA-style
    adaptation for sim2real.
    """

    def __init__(
        self,
        history_length: int,
        step_dim: int,
        embedding_dim: int = 32,
        hidden_channels: int = 32,
    ) -> None:
        """Build the encoder.

        Args:
            history_length: Number of time steps in the window. No layer
                depends on it (time is pooled away); it is stored for
                introspection only.
            step_dim: Features per step; used as the input channel count.
            embedding_dim: Size of the output embedding.
            hidden_channels: Channel width of both conv layers.
        """
        super().__init__()
        self.history_length = history_length
        self.step_dim = step_dim
        self.conv = nn.Sequential(
            # (batch, step_dim, history_length) after transpose
            nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ELU(),
        )
        self.fc = nn.Linear(hidden_channels, embedding_dim)

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        """Encode ``history`` of shape (batch, history_length, step_dim)."""
        # Conv1d expects (batch, channels, seq_len).
        x = history.transpose(1, 2)
        x = self.conv(x)
        # Global average pool over time.
        x = x.mean(dim=-1)
        return self.fc(x)
|
||||
|
||||
|
||||
class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0):
|
||||
def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0, history_length: int = 0, raw_obs_dim: int = 0, embedding_dim: int = 32):
|
||||
Model.__init__(self, observation_space, action_space, device)
|
||||
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
|
||||
DeterministicMixin.__init__(self, clip_actions)
|
||||
|
||||
layers = []
|
||||
in_dim: int = self.num_observations
|
||||
self._history_length = history_length
|
||||
self._raw_obs_dim = raw_obs_dim
|
||||
self._embedding_dim = embedding_dim
|
||||
|
||||
if history_length > 0 and raw_obs_dim > 0:
|
||||
# The observation is [current_obs(raw_obs_dim), history_flat(H * step_dim)].
|
||||
act_dim = self.num_actions
|
||||
step_dim = raw_obs_dim + act_dim
|
||||
self.history_encoder = HistoryEncoder(
|
||||
history_length=history_length,
|
||||
step_dim=step_dim,
|
||||
embedding_dim=embedding_dim,
|
||||
)
|
||||
# MLP input = raw obs + history embedding.
|
||||
in_dim = raw_obs_dim + embedding_dim
|
||||
else:
|
||||
self.history_encoder = None
|
||||
in_dim = self.num_observations
|
||||
|
||||
layers: list[nn.Module] = []
|
||||
for hidden_size in hidden_sizes:
|
||||
layers.append(nn.Linear(in_dim, hidden_size))
|
||||
layers.append(nn.ELU())
|
||||
@@ -32,17 +90,29 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model):
|
||||
elif role == "value":
|
||||
return DeterministicMixin.act(self, inputs, role)
|
||||
|
||||
def _encode(self, states: torch.Tensor) -> torch.Tensor:
    """Split augmented obs into current obs + history, encode, and run the MLP.

    Args:
        states: ``(batch, features)`` tensor. With a history encoder, the
            first ``_raw_obs_dim`` columns are the current observation and
            the remainder is the flattened history window.

    Returns:
        Output of ``self.net`` on ``states`` (no encoder) or on
        ``[obs, embedding]`` (with encoder).
    """
    if self.history_encoder is None:
        return self.net(states)

    obs = states[:, :self._raw_obs_dim]
    hist_flat = states[:, self._raw_obs_dim:]
    step_dim = self._raw_obs_dim + self.num_actions
    # Pin the batch size instead of inferring it with -1, so a history width
    # that disagrees with H * step_dim fails here rather than being folded
    # into the batch axis and surfacing later in torch.cat.
    history = hist_flat.reshape(states.shape[0], self._history_length, step_dim)
    embedding = self.history_encoder(history)
    return self.net(torch.cat([obs, embedding], dim=-1))
|
||||
|
||||
def compute(
    self, inputs: dict[str, torch.Tensor], role: str = ""
) -> tuple[torch.Tensor, ...]:
    """Run the shared trunk and the head selected by ``role``.

    Args:
        inputs: Must contain ``"states"``.
        role: ``"policy"`` or ``"value"``.

    Returns:
        For ``"policy"``: ``(mean, log_std_parameter, {})``.
        For ``"value"``: ``(value, {})``.

    The policy pass caches the trunk output in ``_shared_output``; the next
    value pass consumes and clears it, and otherwise recomputes the trunk.
    """
    if role == "policy":
        # Only the history-aware path is kept; the stale `self.net(...)`
        # assignment from the pre-change code is dropped.
        self._shared_output = self._encode(inputs["states"])
        return self.mean_layer(self._shared_output), self.log_std_parameter, {}
    elif role == "value":
        shared_output = (
            self._shared_output
            if self._shared_output is not None
            else self._encode(inputs["states"])
        )
        self._shared_output = None
        return self.value_layer(shared_output), {}
|
||||
@@ -48,6 +48,10 @@ class TrainerConfig:
|
||||
record_video_every: int = 10_000 # 0 = disabled
|
||||
record_video_fps: int = 0 # 0 = derive from sim dt×substeps
|
||||
|
||||
# History encoder (RMA-style adaptation)
|
||||
history_length: int = 0 # 0 = disabled, >0 = temporal window size
|
||||
embedding_dim: int = 32 # history encoder output dimension
|
||||
|
||||
|
||||
# ── Video-recording trainer ──────────────────────────────────────────
|
||||
|
||||
@@ -173,6 +177,11 @@ class Trainer:
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Determine raw obs dim (without history augmentation).
|
||||
raw_obs_dim = 0
|
||||
if self.config.history_length > 0:
|
||||
raw_obs_dim = self.runner.env.observation_space.shape[0]
|
||||
|
||||
self.model = SharedMLP(
|
||||
observation_space=obs_space,
|
||||
action_space=act_space,
|
||||
@@ -181,6 +190,9 @@ class Trainer:
|
||||
initial_log_std=self.config.initial_log_std,
|
||||
min_log_std=self.config.min_log_std,
|
||||
max_log_std=self.config.max_log_std,
|
||||
history_length=self.config.history_length,
|
||||
raw_obs_dim=raw_obs_dim,
|
||||
embedding_dim=self.config.embedding_dim,
|
||||
)
|
||||
|
||||
models = {"policy": self.model, "value": self.model}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""Unit tests for MuJoCoRunner domain randomization."""
|
||||
"""Unit tests for MuJoCoRunner domain randomization and history buffer."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
|
||||
from src.runners.mujoco import DomainRandConfig, MuJoCoRunnerConfig
|
||||
from src.models.mlp import SharedMLP, HistoryEncoder
|
||||
|
||||
|
||||
class TestDomainRandConfig:
|
||||
@@ -37,3 +40,55 @@ class TestMuJoCoRunnerConfig:
|
||||
assert isinstance(cfg.domain_rand, DomainRandConfig)
|
||||
assert cfg.domain_rand.mass_frac == 0.2
|
||||
assert cfg.domain_rand.friction_frac == 0.3
|
||||
|
||||
def test_history_length_default(self) -> None:
    """History is disabled (length 0) unless a config sets it."""
    cfg = MuJoCoRunnerConfig()
    assert cfg.history_length == 0
|
||||
|
||||
|
||||
class TestHistoryEncoder:
    """Shape checks for the 1D-CNN history encoder."""

    def test_output_shape(self) -> None:
        enc = HistoryEncoder(history_length=10, step_dim=7, embedding_dim=32)
        x = torch.randn(4, 10, 7)  # batch=4, H=10, step_dim=7
        out = enc(x)
        assert out.shape == (4, 32)

    def test_different_embedding_dim(self) -> None:
        enc = HistoryEncoder(history_length=5, step_dim=7, embedding_dim=16)
        x = torch.randn(2, 5, 7)
        out = enc(x)
        assert out.shape == (2, 16)
|
||||
|
||||
|
||||
class TestSharedMLPWithHistory:
    """SharedMLP output shapes with and without the history encoder."""

    def test_no_history(self) -> None:
        """Without history, model works as before."""
        obs_space = spaces.Box(low=-1.0, high=1.0, shape=(6,))
        act_space = spaces.Box(low=-1.0, high=1.0, shape=(1,))
        model = SharedMLP(obs_space, act_space, torch.device("cpu"),
                          hidden_sizes=(32, 32))
        assert model.history_encoder is None
        inp = {"states": torch.randn(4, 6)}
        mean, _, _ = model.compute(inp, role="policy")
        assert mean.shape == (4, 1)

    def test_with_history(self) -> None:
        """With history, model splits obs and encodes history."""
        raw_obs_dim = 6
        act_dim = 1
        H = 10
        step_dim = raw_obs_dim + act_dim  # 7
        aug_dim = raw_obs_dim + H * step_dim  # 6 + 70 = 76

        obs_space = spaces.Box(low=-1.0, high=1.0, shape=(aug_dim,))
        act_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,))
        model = SharedMLP(obs_space, act_space, torch.device("cpu"),
                          hidden_sizes=(32, 32),
                          history_length=H, raw_obs_dim=raw_obs_dim,
                          embedding_dim=32)
        assert model.history_encoder is not None
        inp = {"states": torch.randn(4, aug_dim)}
        mean, _, _ = model.compute(inp, role="policy")
        assert mean.shape == (4, act_dim)
        value, _ = model.compute(inp, role="value")
        assert value.shape == (4, 1)
|
||||
|
||||
Reference in New Issue
Block a user