diff --git a/configs/runner/mjx.yaml b/configs/runner/mjx.yaml index 5c35dbf..4d6fc5a 100644 --- a/configs/runner/mjx.yaml +++ b/configs/runner/mjx.yaml @@ -2,3 +2,4 @@ num_envs: 1024 # MJX shines with many parallel envs device: auto # auto = cuda if available, else cpu dt: 0.002 substeps: 10 +history_length: 10 # RMA-style: 10-step window of (obs, action) pairs diff --git a/configs/runner/mujoco.yaml b/configs/runner/mujoco.yaml index 67702fc..0cb1e5d 100644 --- a/configs/runner/mujoco.yaml +++ b/configs/runner/mujoco.yaml @@ -2,6 +2,7 @@ num_envs: 64 device: auto # auto = cuda if available, else cpu dt: 0.002 substeps: 10 +history_length: 10 # RMA-style: 10-step window of (obs, action) pairs # ── Sim2real: domain randomization ─────────────────────────────── domain_rand: diff --git a/configs/runner/mujoco_single.yaml b/configs/runner/mujoco_single.yaml index 3b61919..d1b4e2c 100644 --- a/configs/runner/mujoco_single.yaml +++ b/configs/runner/mujoco_single.yaml @@ -5,3 +5,4 @@ num_envs: 1 device: cpu dt: 0.002 substeps: 10 +history_length: 10 diff --git a/configs/runner/serial.yaml b/configs/runner/serial.yaml index 44f19ff..69be7de 100644 --- a/configs/runner/serial.yaml +++ b/configs/runner/serial.yaml @@ -8,3 +8,4 @@ port: /dev/cu.usbserial-0001 baud: 115200 dt: 0.02 # control loop period (50 Hz, matches training) no_data_timeout: 2.0 # seconds of silence before declaring disconnect +history_length: 10 # must match training runner diff --git a/configs/training/ppo.yaml b/configs/training/ppo.yaml index 5167942..9d6a355 100644 --- a/configs/training/ppo.yaml +++ b/configs/training/ppo.yaml @@ -18,6 +18,10 @@ max_log_std: 2.0 record_video_every: 10000 +# RMA-style history encoder +history_length: 10 # temporal window (must match runner) +embedding_dim: 32 # history encoder output dimension + # ClearML remote execution (GPU worker) remote: false diff --git a/src/core/runner.py b/src/core/runner.py index 2b064e2..ff2b004 100644 --- a/src/core/runner.py +++ b/src/core/runner.py @@ -14,6 +14,7 @@ T = TypeVar("T") class BaseRunnerConfig: num_envs: int = 1 device: str = "cpu" + history_length: int = 0 # 0 = no history (single obs), >0 = RMA-style class BaseRunner(abc.ABC, Generic[T]): def __init__(self, env: BaseEnv, config: T) -> None: @@ -36,6 +37,26 @@ class BaseRunner(abc.ABC, Generic[T]): self.config.num_envs, dtype=torch.long, device=self.config.device ) + # ── History buffer (RMA-style adaptation) ──────────────── + self._history_len: int = getattr(self.config, "history_length", 0) + if self._history_len > 0: + obs_dim = self.observation_space.shape[0] + act_dim = self.action_space.shape[0] + self._history_obs_dim = obs_dim + self._history_act_dim = act_dim + self._history_step_dim = obs_dim + act_dim # each step stores (obs, action) + # Ring buffer: (num_envs, history_length, obs_dim + act_dim) + self._history_buf = torch.zeros( + self.config.num_envs, self._history_len, self._history_step_dim, + device=self.config.device, + ) + # Augmented observation space: [current_obs, history_flat] + from gymnasium import spaces + aug_dim = obs_dim + self._history_len * self._history_step_dim + self.observation_space = spaces.Box( + low=-torch.inf, high=torch.inf, shape=(aug_dim,) + ) + @property @abc.abstractmethod def num_envs(self) -> int: @@ -63,14 +84,44 @@ class BaseRunner(abc.ABC, Generic[T]): if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None: self._offscreen_renderer.close() + def _augment_obs(self, obs: torch.Tensor) -> torch.Tensor: + """Concatenate history buffer to current obs if history is enabled.""" + if self._history_len <= 0: + return obs + # Flatten history: (num_envs, H, step_dim) → (num_envs, H * step_dim) + hist_flat = self._history_buf.reshape(obs.shape[0], -1) + return torch.cat([obs, hist_flat], dim=-1) + + def _push_history(self, obs: torch.Tensor, actions: torch.Tensor, + env_ids: torch.Tensor | None = None) -> None: + """Push (obs, action) into the ring buffer (shift left, append right).""" + if self._history_len <= 0: + return + step = torch.cat([obs, actions.reshape(obs.shape[0], -1)], dim=-1) + if env_ids is None: + # All envs. + self._history_buf = torch.roll(self._history_buf, -1, dims=1) + self._history_buf[:, -1] = step + else: + self._history_buf[env_ids] = torch.roll( + self._history_buf[env_ids], -1, dims=1 + ) + self._history_buf[env_ids, -1] = step[env_ids] + + def _reset_history(self, env_ids: torch.Tensor) -> None: + """Zero the history buffer for reset envs.""" + if self._history_len > 0: + self._history_buf[env_ids] = 0.0 + def reset(self) -> tuple[torch.Tensor, dict[str, Any]]: all_ids = torch.arange(self.num_envs, device=self.device) qpos, qvel = self._sim_reset(all_ids) self.step_counts.zero_() + self._reset_history(all_ids) state = self.env.build_state(qpos, qvel) obs = self.env.compute_observations(state) - return obs, {} + return self._augment_obs(obs), {} def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]: self._last_actions = actions @@ -83,23 +134,27 @@ class BaseRunner(abc.ABC, Generic[T]): terminated = self.env.compute_terminations(state) truncated = self.env.compute_truncations(self.step_counts) + # Push current (obs, action) into history before augmenting. + self._push_history(obs, actions) + info: dict[str, Any] = {} done = terminated | truncated done_ids = done.nonzero(as_tuple=False).squeeze(-1) if done_ids.numel() > 0: - info["final_observations"] = obs[done_ids].clone() + info["final_observations"] = self._augment_obs(obs)[done_ids].clone() info["final_env_ids"] = done_ids.clone() reset_qpos, reset_qvel = self._sim_reset(done_ids) self.step_counts[done_ids] = 0 + self._reset_history(done_ids) reset_state = self.env.build_state(reset_qpos, reset_qvel) obs[done_ids] = self.env.compute_observations(reset_state) # skrl expects (num_envs, 1) for rewards/terminated/truncated - return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info + return self._augment_obs(obs), rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info def _render_frame(self, env_idx: int = 0) -> np.ndarray: """Return a raw RGB frame. Override in subclass.""" diff --git a/src/models/mlp.py b/src/models/mlp.py index 398c6a5..f4c7558 100644 --- a/src/models/mlp.py +++ b/src/models/mlp.py @@ -3,14 +3,72 @@ import torch.nn as nn from gymnasium import spaces from skrl.models.torch import Model, GaussianMixin, DeterministicMixin + +class HistoryEncoder(nn.Module): + """1D-CNN encoder over a temporal window of (obs, action) pairs. + + Input: (batch, history_length, step_dim) + Output: (batch, embedding_dim) + + Architecture: two temporal conv layers → global average pool → linear. + This lets the policy implicitly infer environment parameters (mass, + friction, gear, etc.) from recent dynamics — the core of RMA-style + adaptation for sim2real. + """ + + def __init__( + self, + history_length: int, + step_dim: int, + embedding_dim: int = 32, + hidden_channels: int = 32, + ) -> None: + super().__init__() + self.conv = nn.Sequential( + # (batch, step_dim, history_length) after transpose + nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1), + nn.ELU(), + nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1), + nn.ELU(), + ) + self.fc = nn.Linear(hidden_channels, embedding_dim) + + def forward(self, history: torch.Tensor) -> torch.Tensor: + """history: (batch, history_length, step_dim).""" + # Conv1d expects (batch, channels, seq_len). + x = history.transpose(1, 2) + x = self.conv(x) + # Global average pool over time. + x = x.mean(dim=-1) + return self.fc(x) + + class SharedMLP(GaussianMixin, DeterministicMixin, Model): - def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0): + def __init__(self, observation_space: spaces.Space, action_space: spaces.Space, device: torch.device, hidden_sizes: tuple[int, ...] = (32, 32), clip_actions: bool = False, clip_log_std: bool = True, min_log_std: float = -2.0, max_log_std: float = 2.0, initial_log_std: float = 0.0, history_length: int = 0, raw_obs_dim: int = 0, embedding_dim: int = 32): Model.__init__(self, observation_space, action_space, device) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std) DeterministicMixin.__init__(self, clip_actions) - layers = [] - in_dim: int = self.num_observations + self._history_length = history_length + self._raw_obs_dim = raw_obs_dim + self._embedding_dim = embedding_dim + + if history_length > 0 and raw_obs_dim > 0: + # The observation is [current_obs(raw_obs_dim), history_flat(H * step_dim)]. + act_dim = self.num_actions + step_dim = raw_obs_dim + act_dim + self.history_encoder = HistoryEncoder( + history_length=history_length, + step_dim=step_dim, + embedding_dim=embedding_dim, + ) + # MLP input = raw obs + history embedding. + in_dim = raw_obs_dim + embedding_dim + else: + self.history_encoder = None + in_dim = self.num_observations + + layers: list[nn.Module] = [] for hidden_size in hidden_sizes: layers.append(nn.Linear(in_dim, hidden_size)) layers.append(nn.ELU()) @@ -32,17 +90,29 @@ class SharedMLP(GaussianMixin, DeterministicMixin, Model): elif role == "value": return DeterministicMixin.act(self, inputs, role) + def _encode(self, states: torch.Tensor) -> torch.Tensor: + """Split augmented obs into current obs + history, encode, concat.""" + if self.history_encoder is None: + return self.net(states) + + obs = states[:, :self._raw_obs_dim] + hist_flat = states[:, self._raw_obs_dim:] + step_dim = self._raw_obs_dim + self.num_actions + history = hist_flat.reshape(-1, self._history_length, step_dim) + embedding = self.history_encoder(history) + return self.net(torch.cat([obs, embedding], dim=-1)) + def compute( self, inputs: dict[str, torch.Tensor], role: str = "" ) -> tuple[torch.Tensor, ...]: if role == "policy": - self._shared_output = self.net(inputs["states"]) + self._shared_output = self._encode(inputs["states"]) return self.mean_layer(self._shared_output), self.log_std_parameter, {} elif role == "value": shared_output = ( self._shared_output if self._shared_output is not None - else self.net(inputs["states"]) + else self._encode(inputs["states"]) ) self._shared_output = None return self.value_layer(shared_output), {} \ No newline at end of file diff --git a/src/training/trainer.py b/src/training/trainer.py index 87ddf75..889fbbf 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -48,6 +48,10 @@ class TrainerConfig: record_video_every: int = 10_000 # 0 = disabled record_video_fps: int = 0 # 0 = derive from sim dt×substeps + # History encoder (RMA-style adaptation) + history_length: int = 0 # 0 = disabled, >0 = temporal window size + embedding_dim: int = 32 # history encoder output dimension + # ── Video-recording trainer ────────────────────────────────────────── @@ -173,6 +177,11 @@ class Trainer: device=device, ) + # Determine raw obs dim (without history augmentation). + raw_obs_dim = 0 + if self.config.history_length > 0: + raw_obs_dim = self.runner.env.observation_space.shape[0] + self.model = SharedMLP( observation_space=obs_space, action_space=act_space, @@ -181,6 +190,9 @@ class Trainer: initial_log_std=self.config.initial_log_std, min_log_std=self.config.min_log_std, max_log_std=self.config.max_log_std, + history_length=self.config.history_length, + raw_obs_dim=raw_obs_dim, + embedding_dim=self.config.embedding_dim, ) models = {"policy": self.model, "value": self.model} diff --git a/tests/test_sim2real.py b/tests/test_sim2real.py index ff673e1..740dcad 100644 --- a/tests/test_sim2real.py +++ b/tests/test_sim2real.py @@ -1,11 +1,14 @@ -"""Unit tests for MuJoCoRunner domain randomization.""" +"""Unit tests for MuJoCoRunner domain randomization and history buffer.""" import dataclasses import numpy as np import pytest +import torch +from gymnasium import spaces from src.runners.mujoco import DomainRandConfig, MuJoCoRunnerConfig +from src.models.mlp import SharedMLP, HistoryEncoder class TestDomainRandConfig: @@ -37,3 +40,55 @@ class TestMuJoCoRunnerConfig: assert isinstance(cfg.domain_rand, DomainRandConfig) assert cfg.domain_rand.mass_frac == 0.2 assert cfg.domain_rand.friction_frac == 0.3 + + def test_history_length_default(self) -> None: + cfg = MuJoCoRunnerConfig() + assert cfg.history_length == 0 + + +class TestHistoryEncoder: + def test_output_shape(self) -> None: + enc = HistoryEncoder(history_length=10, step_dim=7, embedding_dim=32) + x = torch.randn(4, 10, 7) # batch=4, H=10, step_dim=7 + out = enc(x) + assert out.shape == (4, 32) + + def test_different_embedding_dim(self) -> None: + enc = HistoryEncoder(history_length=5, step_dim=7, embedding_dim=16) + x = torch.randn(2, 5, 7) + out = enc(x) + assert out.shape == (2, 16) + + +class TestSharedMLPWithHistory: + def test_no_history(self) -> None: + """Without history, model works as before.""" + obs_space = spaces.Box(low=-1.0, high=1.0, shape=(6,)) + act_space = spaces.Box(low=-1.0, high=1.0, shape=(1,)) + model = SharedMLP(obs_space, act_space, torch.device("cpu"), + hidden_sizes=(32, 32)) + assert model.history_encoder is None + inp = {"states": torch.randn(4, 6)} + mean, log_std, _ = model.compute(inp, role="policy") + assert mean.shape == (4, 1) + + def test_with_history(self) -> None: + """With history, model splits obs and encodes history.""" + raw_obs_dim = 6 + act_dim = 1 + H = 10 + step_dim = raw_obs_dim + act_dim # 7 + aug_dim = raw_obs_dim + H * step_dim # 6 + 70 = 76 + + obs_space = spaces.Box(low=-1.0, high=1.0, shape=(aug_dim,)) + act_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,)) + model = SharedMLP(obs_space, act_space, torch.device("cpu"), + hidden_sizes=(32, 32), + history_length=H, raw_obs_dim=raw_obs_dim, + embedding_dim=32) + assert model.history_encoder is not None + inp = {"states": torch.randn(4, aug_dim)} + mean, log_std, _ = model.compute(inp, role="policy") + assert mean.shape == (4, act_dim) + value, _ = model.compute(inp, role="value") + assert value.shape == (4, 1)