Added a temporal observation history buffer and 1D-CNN encoder so the
policy can implicitly infer environment parameters (mass, friction,
gear ratios, etc.) from recent (obs, action) dynamics.
Architecture:
  history window [(obs₀,a₀), ..., (obs_{H-1},a_{H-1})]
    → 1D-CNN HistoryEncoder → embedding (32-dim)
    → concat [current_obs, embedding] → MLP → action
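
The data flow above can be sketched in PyTorch as follows. This is a minimal sketch, not the code in src/models/mlp.py: the kernel size, channel width, activations, the class name HistoryPolicy, and the layout of the augmented observation (current obs first, then the flattened history) are all assumptions made for illustration. The per-step width of 7 is taken to be 6 observation values plus 1 action value, as inferred from 6 + H×7.

```python
import torch
import torch.nn as nn


class HistoryEncoder(nn.Module):
    """Two temporal Conv1d layers followed by global average pooling (GAP)."""

    def __init__(self, step_dim: int = 7, embedding_dim: int = 32,
                 hidden_channels: int = 32):
        super().__init__()
        # Assumption: kernel_size=3 with padding=1 and ReLU activations.
        self.conv = nn.Sequential(
            nn.Conv1d(step_dim, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, embedding_dim, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        # history: (batch, H, step_dim); Conv1d expects (batch, channels, time).
        features = self.conv(history.transpose(1, 2))
        # GAP over the time axis -> (batch, embedding_dim).
        return features.mean(dim=-1)


class HistoryPolicy(nn.Module):
    """Hypothetical stand-in for SharedMLP with the history path enabled."""

    def __init__(self, raw_obs_dim: int = 6, action_dim: int = 1,
                 history_length: int = 10, embedding_dim: int = 32,
                 hidden_sizes=(128, 128)):
        super().__init__()
        self.raw_obs_dim = raw_obs_dim
        self.history_length = history_length
        self.step_dim = raw_obs_dim + action_dim
        self.encoder = HistoryEncoder(self.step_dim, embedding_dim)
        layers, in_dim = [], raw_obs_dim + embedding_dim
        for size in hidden_sizes:
            layers += [nn.Linear(in_dim, size), nn.Tanh()]
            in_dim = size
        layers.append(nn.Linear(in_dim, action_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, augmented_obs: torch.Tensor) -> torch.Tensor:
        # Assumed layout: [current_obs | flattened (obs, action) history].
        obs = augmented_obs[:, : self.raw_obs_dim]
        history = augmented_obs[:, self.raw_obs_dim:].reshape(
            -1, self.history_length, self.step_dim
        )
        embedding = self.encoder(history)
        return self.mlp(torch.cat([obs, embedding], dim=-1))


policy = HistoryPolicy()
action_mean = policy(torch.zeros(4, 76))  # batch of 4 augmented observations
```

Because of the global average pooling, the encoder output size does not depend on H; only the reshape in the policy ties the model to a specific history_length.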
Components:
- BaseRunner: history ring buffer, _push_history/_reset_history,
  augmented obs space (6 + H×7 = 76 with H=10)
- HistoryEncoder (src/models/mlp.py): 2-layer temporal Conv1d + GAP
- SharedMLP: optional history_length/raw_obs_dim/embedding_dim params;
  splits augmented obs, encodes history, feeds [obs, emb] to MLP
- TrainerConfig: history_length, embedding_dim fields
- All runner configs: history_length=10 by default
- Tests: encoder shape, model with/without history, config defaults
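
To make the 6 + H×7 = 76 arithmetic concrete, here is a rough standalone sketch of a history ring buffer. It is not the BaseRunner implementation: the class HistoryBuffer, zero-padding on reset, the 1-dim action, and the ordering of the augmented observation are assumptions chosen for illustration.

```python
from collections import deque

OBS_DIM = 6          # raw observation size, from "6 + H×7"
ACTION_DIM = 1       # assumption: inferred from the per-step width of 7
HISTORY_LENGTH = 10  # H


class HistoryBuffer:
    """Fixed-length ring buffer of (obs, action) steps (hypothetical helper)."""

    def __init__(self, history_length=HISTORY_LENGTH,
                 obs_dim=OBS_DIM, action_dim=ACTION_DIM):
        self.history_length = history_length
        self.step_dim = obs_dim + action_dim
        # deque(maxlen=...) drops the oldest step automatically on append.
        self.steps = deque(maxlen=history_length)
        self.reset()

    def reset(self):
        # Assumption: the window is zero-filled at the start of an episode.
        self.steps.clear()
        for _ in range(self.history_length):
            self.steps.append([0.0] * self.step_dim)

    def push(self, obs, action):
        self.steps.append(list(obs) + list(action))

    def augment(self, current_obs):
        # Assumed layout: current obs first, then history from oldest to newest.
        flat_history = [value for step in self.steps for value in step]
        return list(current_obs) + flat_history


buffer = HistoryBuffer()
buffer.push([0.1] * OBS_DIM, [0.5])
augmented = buffer.augment([0.2] * OBS_DIM)
print(len(augmented))  # 6 + 10 * 7
```

Whatever layout the runner actually uses, the model has to split the augmented vector the same way, which is why history_length must agree between the runner config and the trainer config.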
39 lines
1.1 KiB
YAML
hidden_sizes: [128, 128]
total_timesteps: 5000000
rollout_steps: 1024
learning_epochs: 4
mini_batches: 4
discount_factor: 0.99
gae_lambda: 0.95
learning_rate: 0.0003
clip_ratio: 0.2
value_loss_scale: 0.5
entropy_loss_scale: 0.05
log_interval: 1000
checkpoint_interval: 50000

initial_log_std: 0.5
min_log_std: -2.0
max_log_std: 2.0

record_video_every: 10000

# RMA-style history encoder
history_length: 10 # temporal window (must match runner)
embedding_dim: 32 # history encoder output dimension

# ClearML remote execution (GPU worker)
remote: false

# ── HPO search ranges ────────────────────────────────────────────────
# Read by scripts/hpo.py; ignored by TrainerConfig during training.
hpo:
  learning_rate: {min: 0.00005, max: 0.001}
  clip_ratio: {min: 0.1, max: 0.3}
  discount_factor: {min: 0.98, max: 0.999}
  gae_lambda: {min: 0.9, max: 0.99}
  entropy_loss_scale: {min: 0.0001, max: 0.1}
  value_loss_scale: {min: 0.1, max: 1.0}
  learning_epochs: {min: 2, max: 8, type: int}
  mini_batches: {values: [2, 4, 8, 16]}
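
The hpo block uses three kinds of entries: continuous {min, max} ranges, integer ranges marked with type: int, and discrete values lists. The sketch below shows one way such a block could be turned into a trial configuration. It is not scripts/hpo.py, whose sampling strategy is not shown here: the function sample_params and the use of plain uniform sampling are assumptions, and the search space is inlined as a dict so the example runs without a YAML parser.

```python
import random

# Mirror of the hpo block, inlined for a self-contained example.
HPO_SPACE = {
    "learning_rate": {"min": 0.00005, "max": 0.001},
    "clip_ratio": {"min": 0.1, "max": 0.3},
    "discount_factor": {"min": 0.98, "max": 0.999},
    "gae_lambda": {"min": 0.9, "max": 0.99},
    "entropy_loss_scale": {"min": 0.0001, "max": 0.1},
    "value_loss_scale": {"min": 0.1, "max": 1.0},
    "learning_epochs": {"min": 2, "max": 8, "type": "int"},
    "mini_batches": {"values": [2, 4, 8, 16]},
}


def sample_params(space, rng):
    """Draw one value per hyperparameter (hypothetical helper)."""
    params = {}
    for name, spec in space.items():
        if "values" in spec:
            # Discrete choice from an explicit list.
            params[name] = rng.choice(spec["values"])
        elif spec.get("type") == "int":
            # Integer range, both endpoints included.
            params[name] = rng.randint(spec["min"], spec["max"])
        else:
            # Assumption: uniform sampling; a real search may use log scale.
            params[name] = rng.uniform(spec["min"], spec["max"])
    return params


trial = sample_params(HPO_SPACE, random.Random(0))
for name, value in trial.items():
    print(f"{name}: {value}")
```

The sampled dict could then override the matching top-level keys (learning_rate, clip_ratio, and so on) before a trial is launched, leaving history_length and embedding_dim fixed.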