♻️ Refactor: add hardware.yaml loader, SMAC3 Successive-Halving HPO strategy, action-bar render overlay, and serial (real-hardware) runner
This commit is contained in:
91
src/core/hardware.py
Normal file
91
src/core/hardware.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Real-hardware configuration — loaded from hardware.yaml next to robot.yaml.
|
||||
|
||||
Provides robot-specific constants for the SerialRunner: encoder specs,
|
||||
safety limits, and reset behaviour. Simulation-only robots simply don't
|
||||
have a hardware.yaml (the loader returns None).
|
||||
|
||||
Usage:
|
||||
hw = load_hardware_config("assets/rotary_cartpole")
|
||||
if hw is not None:
|
||||
counts_per_rev = hw.encoder.ppr * hw.encoder.gear_ratio * 4.0
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class EncoderConfig:
    """Quadrature rotary-encoder parameters for a geared motor."""

    # Pulses per motor-shaft revolution, before 4x quadrature decoding.
    ppr: int = 11
    # Gearbox reduction between motor shaft and output shaft.
    gear_ratio: float = 30.0

    @property
    def counts_per_rev(self) -> float:
        """Total encoder counts per output-shaft revolution (quadrature).

        Quadrature decoding yields 4 edges per pulse, hence the factor 4.
        """
        quadrature_edges = 4.0
        return quadrature_edges * self.gear_ratio * self.ppr
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SafetyConfig:
    """Safety limits enforced by the runner (not the env)."""

    # Hard-termination limit on the motor angle, in degrees.
    # A value of 0 disables the check.
    max_motor_angle_deg: float = 90.0
    # Angle (degrees) where a progressive penalty ramp begins,
    # before the hard limit is reached.
    soft_limit_deg: float = 40.0
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ResetConfig:
    """Parameters for the physical reset procedure."""

    # Drive-to-center stage: bang-bang drive back toward the zero position.
    drive_speed: int = 80  # PWM magnitude for bang-bang drive-to-center
    deadband: int = 15  # encoder count threshold for "centered"
    drive_timeout: float = 3.0  # give up driving after this many seconds

    # Settle-detection stage: wait until the pendulum hangs still.
    settle_angle_deg: float = 2.0  # pendulum angle threshold (degrees)
    settle_vel_dps: float = 5.0  # pendulum velocity threshold (deg/s)
    settle_duration: float = 0.5  # seconds the pendulum must stay still
    settle_timeout: float = 30.0  # give up after this many seconds
|
||||
|
||||
|
||||
@dataclasses.dataclass
class HardwareConfig:
    """Complete real-hardware description for a robot.

    Each section defaults to its own dataclass defaults, so a partial
    hardware.yaml (or a missing section) still yields a valid config.
    """

    encoder: EncoderConfig = dataclasses.field(default_factory=EncoderConfig)
    safety: SafetyConfig = dataclasses.field(default_factory=SafetyConfig)
    reset: ResetConfig = dataclasses.field(default_factory=ResetConfig)
|
||||
|
||||
|
||||
def load_hardware_config(robot_dir: str | Path) -> HardwareConfig | None:
    """Load hardware.yaml from *robot_dir*.

    Returns None if the file doesn't exist, which marks the robot as
    simulation-only.
    """
    base_dir = Path(robot_dir).resolve()
    config_path = base_dir / "hardware.yaml"

    if not config_path.exists():
        # No hardware.yaml => simulation-only robot.
        return None

    # An empty file parses to None; treat it the same as an empty mapping.
    data = yaml.safe_load(config_path.read_text()) or {}

    hw = HardwareConfig(
        encoder=EncoderConfig(**data.get("encoder", {})),
        safety=SafetyConfig(**data.get("safety", {})),
        reset=ResetConfig(**data.get("reset", {})),
    )

    log.debug(
        "hardware_config_loaded",
        robot_dir=str(base_dir),
        counts_per_rev=hw.encoder.counts_per_rev,
        max_motor_angle_deg=hw.safety.max_motor_angle_deg,
    )
    return hw
|
||||
@@ -21,6 +21,7 @@ class RotaryCartPoleConfig(BaseEnvConfig):
|
||||
"""
|
||||
# Reward shaping
|
||||
reward_upright_scale: float = 1.0 # cos(pendulum) when upright = +1
|
||||
speed_penalty_scale: float = 0.1 # penalty for high pendulum velocity near top
|
||||
|
||||
|
||||
class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
@@ -69,11 +70,12 @@ class RotaryCartPoleEnv(BaseEnv[RotaryCartPoleConfig]):
|
||||
# Upright reward: -cos(θ) ∈ [-1, +1]
|
||||
upright = -torch.cos(state.pendulum_angle)
|
||||
|
||||
# Velocity penalties — make spinning expensive but allow swing-up
|
||||
pend_vel_penalty = 0.01 * state.pendulum_vel ** 2
|
||||
motor_vel_penalty = 0.01 * state.motor_vel ** 2
|
||||
# Penalise high pendulum velocity when near the top (upright).
|
||||
# "nearness" is weighted by how upright the pendulum is (0 at bottom, 1 at top).
|
||||
near_top = torch.clamp(-torch.cos(state.pendulum_angle), min=0.0) # 0‥1
|
||||
speed_penalty = self.config.speed_penalty_scale * near_top * state.pendulum_vel.abs()
|
||||
|
||||
return upright - pend_vel_penalty - motor_vel_penalty
|
||||
return upright * self.config.reward_upright_scale - speed_penalty
|
||||
|
||||
def compute_terminations(self, state: RotaryCartPoleState) -> torch.Tensor:
|
||||
# No early termination — episode runs for max_steps (truncation only).
|
||||
|
||||
1
src/hpo/__init__.py
Normal file
1
src/hpo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Hyperparameter optimization — SMAC3 + ClearML Successive Halving."""
|
||||
636
src/hpo/smac3.py
Normal file
636
src/hpo/smac3.py
Normal file
@@ -0,0 +1,636 @@
|
||||
# Requires: pip install smac==2.0.0 ConfigSpace==0.4.20
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from clearml import Task
|
||||
from clearml.automation.optimization import Objective, SearchStrategy
|
||||
from clearml.automation.parameters import Parameter
|
||||
from clearml.backend_interface.session import SendError
|
||||
from ConfigSpace import (
|
||||
CategoricalHyperparameter,
|
||||
ConfigurationSpace,
|
||||
UniformFloatHyperparameter,
|
||||
UniformIntegerHyperparameter,
|
||||
)
|
||||
from smac import MultiFidelityFacade
|
||||
from smac.intensifier.successive_halving import SuccessiveHalving
|
||||
from smac.runhistory.dataclasses import TrialInfo, TrialValue
|
||||
from smac.scenario import Scenario
|
||||
|
||||
|
||||
def retry_on_error(max_retries=5, initial_delay=2.0, backoff=2.0, exceptions=(Exception,)):
    """Decorator factory: retry the wrapped call on failure with exponential backoff.

    After the final attempt fails the wrapper returns None instead of
    re-raising, so callers must be prepared to handle a None result.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            wait = initial_delay
            last_attempt = max_retries - 1
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == last_attempt:
                        # Swallow the error: the caller checks for None.
                        return None
                    time.sleep(wait)
                    wait *= backoff
            return None

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def _encode_param_name(name: str) -> str:
|
||||
"""Encode parameter name for ConfigSpace (replace / with __SLASH__)"""
|
||||
return name.replace("/", "__SLASH__")
|
||||
|
||||
|
||||
def _decode_param_name(name: str) -> str:
|
||||
"""Decode parameter name back to original (replace __SLASH__ with /)"""
|
||||
return name.replace("__SLASH__", "/")
|
||||
|
||||
|
||||
def _convert_param_to_cs(param: Parameter):
    """Translate a ClearML ``Parameter`` into a ConfigSpace hyperparameter.

    Targets ConfigSpace>=1.x, which dropped the 'q' (quantization) argument;
    quantized integer ranges are emulated with an explicit choice list.
    """
    # ConfigSpace rejects names containing '/', so escape them first.
    cs_name = _encode_param_name(param.name)

    # Explicit discrete set of values -> categorical.
    if hasattr(param, "values"):
        return CategoricalHyperparameter(name=cs_name, choices=list(param.values))

    # Anything without a numeric range is unsupported.
    if not (hasattr(param, "min_value") and hasattr(param, "max_value")):
        raise ValueError(f"Unsupported Parameter type: {type(param)}")

    lo = param.min_value
    hi = param.max_value
    use_log = getattr(param, "log_scale", False)

    if isinstance(lo, int) and isinstance(hi, int):
        # Integer range, possibly quantized via step_size.
        if hasattr(param, "step_size"):
            step = int(param.step_size)
            if step != 1:
                # No native quantization in ConfigSpace>=1.x: enumerate the grid.
                grid = list(range(lo, hi + 1, step))
                return CategoricalHyperparameter(name=cs_name, choices=grid)
        return UniformIntegerHyperparameter(name=cs_name, lower=lo, upper=hi, log=use_log)

    # Mixed or non-integer bounds -> treat as a float range.
    return UniformFloatHyperparameter(name=cs_name, lower=float(lo), upper=float(hi), log=use_log)
|
||||
|
||||
|
||||
class OptimizerSMAC(SearchStrategy):
    """SMAC3-based hyperparameter optimizer, matching the OptimizerBOHB interface.

    Drives a *manual* Successive Halving schedule over remotely executed
    ClearML tasks: a first rung of configurations is sampled from SMAC at the
    minimum budget, and whenever a rung completes the top ``1/eta`` fraction
    is promoted to the next (eta-times larger) budget, continuing training
    from the previous task's checkpoint via its task ID.
    """

    def __init__(
        self,
        base_task_id: str,
        hyper_parameters: Sequence[Parameter],
        objective_metric: Objective,
        execution_queue: str,
        num_concurrent_workers: int,
        min_iteration_per_job: int,
        max_iteration_per_job: int,
        total_max_jobs: int,
        pool_period_min: float = 2.0,
        time_limit_per_job: float | None = None,
        compute_time_limit: float | None = None,
        **smac_kwargs: Any,
    ):
        # Initialize base SearchStrategy
        super().__init__(
            base_task_id=base_task_id,
            hyper_parameters=hyper_parameters,
            objective_metric=objective_metric,
            execution_queue=execution_queue,
            num_concurrent_workers=num_concurrent_workers,
            pool_period_min=pool_period_min,
            time_limit_per_job=time_limit_per_job,
            compute_time_limit=compute_time_limit,
            min_iteration_per_job=min_iteration_per_job,
            max_iteration_per_job=max_iteration_per_job,
            total_max_jobs=total_max_jobs,
        )

        # Expose for internal use (base class stores these as private attributes).
        # NOTE(review): assumes SearchStrategy also exposes total_max_jobs,
        # compute_time_limit and pool_period_minutes publicly — confirm against
        # the installed clearml version.
        self.execution_queue = self._execution_queue
        self.min_iterations = min_iteration_per_job
        self.max_iterations = max_iteration_per_job
        self.num_concurrent_workers = self._num_concurrent_workers

        # Objective details.
        # Handle both single objective (string) and multi-objective (list) cases;
        # only the first objective/series is used.
        if isinstance(self._objective_metric.title, list):
            self.metric_title = self._objective_metric.title[0]
        else:
            self.metric_title = self._objective_metric.title

        if isinstance(self._objective_metric.series, list):
            self.metric_series = self._objective_metric.series[0]
        else:
            self.metric_series = self._objective_metric.series

        # ClearML Objective stores sign as a list, e.g., ['max'] or ['min'];
        # older versions used an 'order' attribute instead.
        objective_sign = getattr(self._objective_metric, "sign", None) or getattr(self._objective_metric, "order", None)

        # Handle list case - extract first element
        if isinstance(objective_sign, list):
            objective_sign = objective_sign[0] if objective_sign else "max"

        # Default to max if nothing found
        if objective_sign is None:
            objective_sign = "max"

        self.maximize_metric = str(objective_sign).lower() in ("max", "max_global")

        # Build the ConfigSpace from the ClearML parameter definitions.
        self.config_space = ConfigurationSpace(seed=42)
        for p in self._hyper_parameters:
            cs_hp = _convert_param_to_cs(p)
            self.config_space.add(cs_hp)

        # Configure SMAC Scenario
        scenario = Scenario(
            configspace=self.config_space,
            n_trials=self.total_max_jobs,
            min_budget=float(self.min_iterations),
            max_budget=float(self.max_iterations),
            walltime_limit=(self.compute_time_limit * 60) if self.compute_time_limit else None,
            deterministic=True,
        )

        # Successive Halving intensifier (NOT Hyperband): Hyperband runs
        # multiple brackets with different starting budgets, which is wasteful
        # here. With SH, ALL configs start at min_budget and only the best get
        # promoted. eta is the reduction factor (default 3 = keep top 1/3 per
        # rung) and can be overridden via smac_kwargs.
        eta = smac_kwargs.pop("eta", 3)
        intensifier = SuccessiveHalving(scenario=scenario, eta=eta, **smac_kwargs)

        # The facade requires a target_function, but trials actually run
        # remotely via ClearML; a dummy is supplied and ask/tell is used.
        self.smac = MultiFidelityFacade(
            scenario=scenario,
            target_function=lambda config, budget, seed: 0.0,
            intensifier=intensifier,
            overwrite=True,
        )

        # Bookkeeping
        self.running_tasks = {}  # task_id -> trial info
        self.task_start_times = {}  # task_id -> start time (for timeout)
        self.completed_results = []
        self.best_score_so_far = float("-inf") if self.maximize_metric else float("inf")
        self.time_limit_per_job = time_limit_per_job  # per-trial limit, minutes

        # Checkpoint continuation tracking: config_key -> {budget: task_id}.
        # Used to find the previous task's checkpoint when promoting a config.
        self.config_to_tasks = {}

        # Manual Successive Halving control
        self.eta = eta
        self.current_budget = float(self.min_iterations)
        self.configs_at_budget = {}  # budget -> list of (config, score, trial)
        self.pending_configs = []  # (trial, prev_task_id) awaiting launch at current_budget
        self.evaluated_at_budget = []  # (config, score, trial, task_id) for current budget
        self.smac_asked_configs = set()  # configs SMAC has already proposed

        # Size the first rung so the promotion series n + n/eta + n/eta^2 + ...
        # (one term per rung) sums to roughly total_max_jobs.
        num_rungs = 1
        b = float(self.min_iterations)
        while b * eta <= self.max_iterations:
            num_rungs += 1
            b *= eta

        series_sum = sum(1.0 / (eta**i) for i in range(num_rungs))
        self.initial_rung_size = int(self.total_max_jobs / series_sum)
        self.initial_rung_size = max(self.initial_rung_size, self.num_concurrent_workers)  # at least num_workers
        self.configs_needed_for_rung = self.initial_rung_size  # configs still needed for current rung
        self.rung_closed = False  # all configs for the current rung collected?

    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
    def _get_task_safe(self, task_id: str):
        """Safely get a ClearML task; returns None after exhausted retries."""
        return Task.get_task(task_id=task_id)

    @retry_on_error(max_retries=3, initial_delay=5.0, exceptions=(ValueError, SendError, ConnectionError))
    def _launch_task(self, config: dict, budget: float, prev_task_id: str | None = None):
        """Clone the base task, apply the config, and enqueue it for execution.

        Args:
            config: Hyperparameter configuration dict (ConfigSpace-encoded names).
            budget: Number of epochs to train.
            prev_task_id: Optional task ID from a previous budget to continue
                from (checkpoint continuation).

        Returns:
            The enqueued Task clone, or None if launching failed.
        """
        base = self._get_task_safe(task_id=self._base_task_id)
        if base is None:
            return None

        clone = Task.clone(
            source_task=base,
            name=f"HPO Trial - {base.name}",
            parent=Task.current_task().id,  # the HPO controller task is the parent
        )
        # Override hyperparameters on the clone.
        for k, v in config.items():
            # Decode parameter name back to original (with slashes).
            original_name = _decode_param_name(k)
            # Convert numpy scalars to Python built-in types.
            if hasattr(v, "item"):  # numpy scalar
                param_value = v.item()
            elif isinstance(v, int | float | str | bool):
                param_value = type(v)(v)  # ensure it's the built-in type
            else:
                param_value = v
            clone.set_parameter(original_name, param_value)
        # Override the epochs budget when running multi-fidelity.
        if self.max_iterations != self.min_iterations:
            clone.set_parameter("Hydra/training.max_epochs", int(budget))
        else:
            clone.set_parameter("Hydra/training.max_epochs", int(self.max_iterations))

        # Pass the previous task's ID so the worker can download its checkpoint.
        if prev_task_id:
            clone.set_parameter("Hydra/training.resume_from_task_id", prev_task_id)

        Task.enqueue(task=clone, queue_name=self.execution_queue)
        # Track start time for timeout enforcement.
        self.task_start_times[clone.id] = time.time()
        return clone

    def start(self):
        """Main optimization loop: launch rungs, poll workers, promote winners.

        Returns the list of completed trial result dicts.
        """
        controller = Task.current_task()
        total_launched = 0

        # Keep launching & collecting until the trial budget is exhausted.
        while total_launched < self.total_max_jobs:
            # Check if the current budget rung is complete BEFORE asking for new
            # trials (no running tasks, no pending configs, and we have results).
            if not self.running_tasks and not self.pending_configs and self.evaluated_at_budget:
                # Rung complete! Promote top performers to the next budget.
                self.configs_at_budget[self.current_budget] = self.evaluated_at_budget.copy()

                # Sort by score (best first).
                sorted_configs = sorted(
                    self.evaluated_at_budget,
                    key=lambda x: x[1],  # score
                    reverse=self.maximize_metric,
                )

                # Move to the next budget?
                next_budget = self.current_budget * self.eta
                if next_budget <= self.max_iterations:
                    # Promote the top 1/eta fraction (at least one config).
                    n_promote = max(1, len(sorted_configs) // self.eta)
                    promoted = sorted_configs[:n_promote]

                    # Update budget and reset for the next rung.
                    self.current_budget = next_budget
                    self.evaluated_at_budget = []
                    self.configs_needed_for_rung = 0  # promoted configs are all we need
                    self.rung_closed = True  # rung is pre-filled with promoted configs

                    # Re-queue promoted configs at the new budget, carrying the
                    # previous task ID for checkpoint continuation.
                    for _cfg, _score, old_trial, prev_task_id in promoted:
                        new_trial = TrialInfo(
                            config=old_trial.config,
                            instance=old_trial.instance,
                            seed=old_trial.seed,
                            budget=self.current_budget,
                        )
                        self.pending_configs.append((new_trial, prev_task_id))
                else:
                    # All budgets complete
                    break

            # Fill pending_configs with new trials ONLY while the rung is open.
            # First rung: ask SMAC for initial_rung_size configs total.
            # Later rungs: only promoted configs (rung is already closed).
            while (
                not self.rung_closed
                and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
                < self.initial_rung_size
                and total_launched < self.total_max_jobs
            ):
                trial = self.smac.ask()
                if trial is None:
                    self.rung_closed = True
                    break
                # TrialInfo is frozen, so build a new one with the forced budget.
                trial_with_budget = TrialInfo(
                    config=trial.config,
                    instance=trial.instance,
                    seed=trial.seed,
                    budget=self.current_budget,
                )
                cfg_key = str(sorted(trial.config.items()))
                if cfg_key not in self.smac_asked_configs:
                    self.smac_asked_configs.add(cfg_key)
                    # No previous task for brand-new configs.
                    self.pending_configs.append((trial_with_budget, None))

            # Close the rung once enough configs have been collected.
            if (
                not self.rung_closed
                and len(self.pending_configs) + len(self.evaluated_at_budget) + len(self.running_tasks)
                >= self.initial_rung_size
            ):
                self.rung_closed = True

            # Launch pending configs up to the concurrency limit.
            while self.pending_configs and len(self.running_tasks) < self.num_concurrent_workers:
                trial, prev_task_id = self.pending_configs.pop(0)
                t = self._launch_task(trial.config, self.current_budget, prev_task_id=prev_task_id)
                if t is None:
                    # Launch failed: report the worst possible cost to SMAC.
                    # SMAC always *minimizes* cost, so the worst cost is +inf
                    # regardless of our metric direction. (The original code
                    # reported -inf when minimizing, which would have marked a
                    # failed launch as the best trial.)
                    cost = float("inf")
                    self.smac.tell(trial, TrialValue(cost=cost))
                    total_launched += 1
                    continue
                self.running_tasks[t.id] = trial

                # Record which task evaluated this config at this budget, so a
                # later promotion can resume from its checkpoint.
                cfg_key = str(sorted(trial.config.items()))
                if cfg_key not in self.config_to_tasks:
                    self.config_to_tasks[cfg_key] = {}
                self.config_to_tasks[cfg_key][self.current_budget] = t.id

                total_launched += 1

            if not self.running_tasks and not self.pending_configs:
                break

            # Poll workers for finished or timed-out tasks.
            done = []
            timed_out = []
            failed_to_check = []
            for tid in self.running_tasks:
                try:
                    task = self._get_task_safe(task_id=tid)
                    if task is None:
                        failed_to_check.append(tid)
                        continue

                    st = task.get_status()

                    # Any terminal status counts as done; failures are scored below.
                    if st == Task.TaskStatusEnum.completed or st in (
                        Task.TaskStatusEnum.failed,
                        Task.TaskStatusEnum.stopped,
                    ):
                        done.append(tid)
                    # Check for timeout (if a per-trial time limit is set).
                    elif self.time_limit_per_job and tid in self.task_start_times:
                        elapsed_minutes = (time.time() - self.task_start_times[tid]) / 60.0
                        if elapsed_minutes > self.time_limit_per_job:
                            with contextlib.suppress(Exception):
                                task.mark_stopped(force=True)
                            timed_out.append(tid)
                except Exception:
                    # Possibly transient (network/server): only give up on a
                    # task after several consecutive failed status checks.
                    if not hasattr(self, "_task_check_failures"):
                        self._task_check_failures = {}
                    self._task_check_failures[tid] = self._task_check_failures.get(tid, 0) + 1
                    if self._task_check_failures[tid] >= 5:  # 5 consecutive failures
                        failed_to_check.append(tid)
                        del self._task_check_failures[tid]

            # Process tasks that could not be checked at all: score as failed.
            for tid in failed_to_check:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]
                # Worst possible score for our metric direction; the cost sign
                # flip below converts it to SMAC's minimized cost (+inf).
                res = float("-inf") if self.maximize_metric else float("inf")
                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                        "failed": True,
                    }
                )
                # Store the result with task_id for checkpoint tracking.
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

            # Process completed tasks.
            for tid in done:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]

                # Clear any accumulated transient check failures for this task.
                if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
                    del self._task_check_failures[tid]

                task = self._get_task_safe(task_id=tid)
                if task is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                else:
                    res = self._get_objective(task)

                # Missing/degenerate objective => worst score for our direction.
                if res is None or res == float("-inf") or res == float("inf"):
                    res = float("-inf") if self.maximize_metric else float("inf")

                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                    }
                )

                # Store result for this budget rung with task_id for checkpoints.
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

                iteration = len(self.completed_results)

                # Always report the trial score (even if it's bad).
                if res is not None and res != float("-inf") and res != float("inf"):
                    controller.get_logger().report_scalar(
                        title="Optimization", series="trial_score", value=res, iteration=iteration
                    )
                controller.get_logger().report_scalar(
                    title="Optimization",
                    series="trial_budget",
                    value=tri.budget or self.max_iterations,
                    iteration=iteration,
                )

                # Update best-score tracking based on actual results.
                if res is not None and res != float("-inf") and res != float("inf"):
                    if self.maximize_metric:
                        self.best_score_so_far = max(self.best_score_so_far, res)
                    elif res < self.best_score_so_far:
                        self.best_score_so_far = res

                # Always report best score so far (flat line when no improvement).
                if self.best_score_so_far != float("-inf") and self.best_score_so_far != float("inf"):
                    controller.get_logger().report_scalar(
                        title="Optimization", series="best_score", value=self.best_score_so_far, iteration=iteration
                    )

                # Report running statistics.
                valid_scores = [
                    r["value"]
                    for r in self.completed_results
                    if r["value"] is not None and r["value"] != float("-inf") and r["value"] != float("inf")
                ]
                if valid_scores:
                    controller.get_logger().report_scalar(
                        title="Optimization",
                        series="mean_score",
                        value=sum(valid_scores) / len(valid_scores),
                        iteration=iteration,
                    )
                controller.get_logger().report_scalar(
                    title="Progress",
                    series="completed_trials",
                    value=len(self.completed_results),
                    iteration=iteration,
                )
                controller.get_logger().report_scalar(
                    title="Progress", series="running_tasks", value=len(self.running_tasks), iteration=iteration
                )

            # Process timed-out tasks: score with the last objective value seen.
            for tid in timed_out:
                tri = self.running_tasks.pop(tid)
                if tid in self.task_start_times:
                    del self.task_start_times[tid]

                # Clear any accumulated transient check failures for this task.
                if hasattr(self, "_task_check_failures") and tid in self._task_check_failures:
                    del self._task_check_failures[tid]

                # Try to get the last objective value before the timeout.
                task = self._get_task_safe(task_id=tid)
                if task is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                else:
                    res = self._get_objective(task)

                if res is None:
                    res = float("-inf") if self.maximize_metric else float("inf")
                cost = -res if self.maximize_metric else res
                self.smac.tell(tri, TrialValue(cost=cost))
                self.completed_results.append(
                    {
                        "task_id": tid,
                        "config": tri.config,
                        "budget": tri.budget,
                        "value": res,
                        "timed_out": True,
                    }
                )

                # Store the timed-out result for this budget rung with task_id.
                self.evaluated_at_budget.append((tri.config, res, tri, tid))

            # Sleep between polling rounds (base class stores minutes).
            time.sleep(self.pool_period_minutes * 60)
            if self.compute_time_limit and controller.get_runtime() > self.compute_time_limit * 60:
                break

        # Finalize
        self._finalize()
        return self.completed_results

    @retry_on_error(max_retries=3, initial_delay=2.0, exceptions=(SendError, ConnectionError, KeyError))
    def _get_objective(self, task: Task):
        """Get the objective metric value for *task*; None if unavailable."""
        if task is None:
            return None

        try:
            m = task.get_last_scalar_metrics()
            if not m:
                return None

            metric_data = m[self.metric_title][self.metric_series]

            # ClearML returns a dict with 'last', 'min', 'max' keys for this
            # series over ALL logged iterations. Use 'max' when maximizing
            # (best performance achieved), 'min' when minimizing, and fall
            # back to 'last' otherwise.
            if self.maximize_metric and "max" in metric_data:
                result = metric_data["max"]
            elif not self.maximize_metric and "min" in metric_data:
                result = metric_data["min"]
            else:
                result = metric_data["last"]
            return result
        except Exception:
            # (KeyError was listed separately before, but Exception subsumes it.)
            return None

    def _finalize(self):
        """Report the final best score and upload the best config artifact."""
        controller = Task.current_task()
        # Report final best score
        controller.get_logger().report_text(f"Final best score: {self.best_score_so_far}")

        # Also try to get SMAC's incumbent for comparison.
        try:
            incumbent = self.smac.intensifier.get_incumbent()
            if incumbent is not None:
                runhistory = self.smac.runhistory
                # Try different ways to get the incumbent's cost.
                incumbent_cost = None
                try:
                    incumbent_cost = runhistory.get_cost(incumbent)
                except Exception:
                    # Fallback: search through the runhistory manually.
                    for trial_key, trial_value in runhistory.items():
                        trial_config = runhistory.get_config(trial_key.config_id)
                        if trial_config == incumbent and (incumbent_cost is None or trial_value.cost < incumbent_cost):
                            incumbent_cost = trial_value.cost

                if incumbent_cost is not None:
                    score = -incumbent_cost if self.maximize_metric else incumbent_cost
                    controller.get_logger().report_text(f"SMAC incumbent: {incumbent}, score: {score}")
                    controller.upload_artifact(
                        "best_config",
                        {"config": dict(incumbent), "score": score, "our_best_score": self.best_score_so_far},
                    )
                else:
                    controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
        except Exception as e:
            controller.get_logger().report_text(f"Error getting SMAC incumbent: {e}")
            controller.upload_artifact("best_config", {"our_best_score": self.best_score_so_far})
|
||||
@@ -214,6 +214,7 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
"""Offscreen render — copies one env's state from GPU to CPU."""
|
||||
self._render_data.qpos[:] = np.asarray(self._batch_data.qpos[env_idx])
|
||||
self._render_data.qvel[:] = np.asarray(self._batch_data.qvel[env_idx])
|
||||
self._render_data.ctrl[:] = np.asarray(self._batch_data.ctrl[env_idx])
|
||||
mujoco.mj_forward(self._mj_model, self._render_data)
|
||||
|
||||
if not hasattr(self, "_offscreen_renderer"):
|
||||
@@ -221,4 +222,10 @@ class MJXRunner(BaseRunner[MJXRunnerConfig]):
|
||||
self._mj_model, width=640, height=480,
|
||||
)
|
||||
self._offscreen_renderer.update_scene(self._render_data)
|
||||
return self._offscreen_renderer.render()
|
||||
frame = self._offscreen_renderer.render().copy()
|
||||
|
||||
# Import shared overlay helper from mujoco runner
|
||||
from src.runners.mujoco import _draw_action_overlay
|
||||
ctrl_val = float(self._render_data.ctrl[0]) if self._mj_model.nu > 0 else 0.0
|
||||
_draw_action_overlay(frame, ctrl_val)
|
||||
return frame
|
||||
|
||||
@@ -283,4 +283,43 @@ class MuJoCoRunner(BaseRunner[MuJoCoRunnerConfig]):
|
||||
)
|
||||
mujoco.mj_forward(self._model, self._data[env_idx])
|
||||
self._offscreen_renderer.update_scene(self._data[env_idx])
|
||||
return self._offscreen_renderer.render()
|
||||
frame = self._offscreen_renderer.render().copy()
|
||||
|
||||
# Draw action bar overlay — shows ctrl[0] as a horizontal bar
|
||||
ctrl_val = float(self._data[env_idx].ctrl[0]) if self._model.nu > 0 else 0.0
|
||||
_draw_action_overlay(frame, ctrl_val)
|
||||
return frame
|
||||
|
||||
|
||||
def _draw_action_overlay(frame: np.ndarray, action: float) -> None:
|
||||
"""Draw an action bar + text on a rendered frame (no OpenCV needed).
|
||||
|
||||
Bar is centered horizontally: green to the right (+), red to the left (-).
|
||||
"""
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
# Bar geometry
|
||||
bar_y = h - 30
|
||||
bar_h = 16
|
||||
bar_x_center = w // 2
|
||||
bar_half_w = w // 4 # max half-width of the bar
|
||||
bar_x_left = bar_x_center - bar_half_w
|
||||
bar_x_right = bar_x_center + bar_half_w
|
||||
|
||||
# Background (dark grey)
|
||||
frame[bar_y:bar_y + bar_h, bar_x_left:bar_x_right] = [40, 40, 40]
|
||||
|
||||
# Filled bar
|
||||
fill_len = int(abs(action) * bar_half_w)
|
||||
if action > 0:
|
||||
color = [60, 200, 60] # green
|
||||
x0 = bar_x_center
|
||||
x1 = min(bar_x_center + fill_len, bar_x_right)
|
||||
else:
|
||||
color = [200, 60, 60] # red
|
||||
x1 = bar_x_center
|
||||
x0 = max(bar_x_center - fill_len, bar_x_left)
|
||||
frame[bar_y:bar_y + bar_h, x0:x1] = color
|
||||
|
||||
# Center tick mark (white)
|
||||
frame[bar_y:bar_y + bar_h, bar_x_center - 1:bar_x_center + 1] = [255, 255, 255]
|
||||
571
src/runners/serial.py
Normal file
571
src/runners/serial.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""Serial runner — real hardware over USB/serial (ESP32).
|
||||
|
||||
Implements the BaseRunner interface for a single physical robot.
|
||||
All physics come from the real world; the runner translates between
|
||||
the ESP32 serial protocol and the qpos/qvel tensors that BaseRunner
|
||||
and BaseEnv expect.
|
||||
|
||||
Serial protocol (ESP32 firmware):
|
||||
Commands sent TO the ESP32:
|
||||
G — start streaming state lines
|
||||
H — stop streaming
|
||||
M<int> — set motor PWM speed (-255 … 255)
|
||||
|
||||
State lines received FROM the ESP32:
|
||||
S,<ms>,<enc>,<rpm>,<motor_speed>,<at_limit>,
|
||||
<pend_deg>,<pend_vel>,<target_speed>,<braking>,
|
||||
<enc_vel_cps>,<pendulum_ok>
|
||||
(12 comma-separated fields after the ``S`` prefix)
|
||||
|
||||
A daemon thread continuously reads the serial stream so the control
|
||||
loop never blocks on I/O.
|
||||
|
||||
Usage:
|
||||
python train.py env=rotary_cartpole runner=serial training=ppo_real
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.env import BaseEnv
|
||||
from src.core.hardware import HardwareConfig, load_hardware_config
|
||||
from src.core.runner import BaseRunner, BaseRunnerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SerialRunnerConfig(BaseRunnerConfig):
    """Configuration for serial communication with the ESP32.

    The serial runner drives exactly one physical robot, so ``num_envs``
    is fixed at 1 and all tensors live on the CPU.
    """

    num_envs: int = 1  # always 1 — single physical robot
    device: str = "cpu"  # hardware state is assembled on the CPU

    port: str = "/dev/cu.usbserial-0001"  # serial device node for the ESP32
    baud: int = 115200  # serial baud rate
    dt: float = 0.02  # control loop period (seconds), 50 Hz
    no_data_timeout: float = 2.0  # seconds of silence → disconnect
    encoder_jump_threshold: int = 200  # encoder tick jump → reboot
||||
class SerialRunner(BaseRunner[SerialRunnerConfig]):
|
||||
"""BaseRunner implementation that talks to real hardware over serial.
|
||||
|
||||
Maps the ESP32 serial protocol to qpos/qvel tensors so the existing
|
||||
RotaryCartPoleEnv (or any compatible env) works unchanged.
|
||||
|
||||
qpos layout: [motor_angle_rad, pendulum_angle_rad]
|
||||
qvel layout: [motor_vel_rad_s, pendulum_vel_rad_s]
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseRunner interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def num_envs(self) -> int:
|
||||
return 1
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return torch.device("cpu")
|
||||
|
||||
    def _sim_initialize(self, config: SerialRunnerConfig) -> None:
        """Open the serial link and prepare all runner state.

        Loads hardware.yaml (encoder/safety/reset parameters), opens the
        serial port, seeds the state cache, then starts the background
        reader thread and the ESP32 state stream.

        Raises
        ------
        FileNotFoundError : if the robot directory has no hardware.yaml.
        """
        # Load hardware description (encoder, safety, reset params).
        hw = load_hardware_config(self.env.config.robot_path)
        if hw is None:
            raise FileNotFoundError(
                f"hardware.yaml not found in {self.env.config.robot_path}. "
                "The serial runner requires a hardware config for encoder, "
                "safety, and reset parameters."
            )
        self._hw: HardwareConfig = hw
        self._counts_per_rev: float = hw.encoder.counts_per_rev
        # 0.0 means "hard limit disabled" (max_motor_angle_deg <= 0).
        self._max_motor_angle_rad: float = (
            math.radians(hw.safety.max_motor_angle_deg)
            if hw.safety.max_motor_angle_deg > 0
            else 0.0
        )

        # Joint dimensions for the rotary cartpole (motor + pendulum).
        self._nq = 2
        self._nv = 2

        # Import serial here so it's not a hard dependency for sim-only users.
        import serial as _serial

        self._serial_mod = _serial

        self.ser: _serial.Serial = _serial.Serial(
            config.port, config.baud, timeout=0.05
        )
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()

        # Internal state tracking.
        self._rebooted: bool = False
        self._serial_disconnected: bool = False
        self._last_esp_ms: int = 0
        self._last_data_time: float = time.monotonic()
        self._last_encoder_count: int = 0
        self._streaming: bool = False

        # Latest parsed state (updated by the reader thread).
        self._latest_state: dict[str, Any] = {
            "timestamp_ms": 0,
            "encoder_count": 0,
            "rpm": 0.0,
            "motor_speed": 0,
            "at_limit": False,
            "pendulum_angle": 0.0,
            "pendulum_velocity": 0.0,
            "target_speed": 0,
            "braking": False,
            "enc_vel_cps": 0.0,
            "pendulum_ok": False,
        }
        self._state_lock = threading.Lock()
        self._state_event = threading.Event()

        # Start background serial reader.
        self._reader_running = True
        self._reader_thread = threading.Thread(
            target=self._serial_reader, daemon=True
        )
        self._reader_thread.start()

        # Start streaming.
        self._send("G")
        self._streaming = True
        self._last_data_time = time.monotonic()

        # Track wall-clock time of last step for PPO-gap detection.
        self._last_step_time: float = time.monotonic()
||||
    def _sim_step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Send one motor command, wait one control period, read back state.

        Parameters
        ----------
        actions : (1, 1) tensor holding the normalised action in [-1, 1].

        Returns
        -------
        (qpos, qvel) : (1, 2) float32 tensors —
            qpos = [motor_angle_rad, pendulum_angle_rad],
            qvel = [motor_vel_rad_s, pendulum_vel_rad_s].
        """
        now = time.monotonic()

        # Detect PPO update gap: if more than 0.5s since last step,
        # the optimizer was running and no motor commands were sent.
        # Trigger a full reset so the robot starts from a clean state.
        gap = now - self._last_step_time
        if gap > 0.5:
            logger.info(
                "PPO update gap detected (%.1f s) — resetting before resuming.",
                gap,
            )
            self._send("M0")
            all_ids = torch.arange(self.num_envs, device=self.device)
            self._sim_reset(all_ids)
            self.step_counts.zero_()

        step_start = time.monotonic()

        # Map normalised action [-1, 1] → PWM [-255, 255].
        action_val = float(actions[0, 0].clamp(-1.0, 1.0))
        motor_speed = int(action_val * 255)
        self._send(f"M{motor_speed}")

        # Enforce dt wall-clock timing.
        elapsed = time.monotonic() - step_start
        remaining = self.config.dt - elapsed
        if remaining > 0:
            time.sleep(remaining)

        # Read latest sensor data (non-blocking — dt sleep ensures freshness).
        state = self._read_state()

        # Encoder counts → radians; counts/s → rad/s; degrees → radians.
        motor_angle = (
            state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
        )
        motor_vel = (
            state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
        )
        pendulum_angle = math.radians(state["pendulum_angle"])
        pendulum_vel = math.radians(state["pendulum_velocity"])

        # Cache motor angle for safety check in step() — avoids a second read.
        self._last_motor_angle_rad = motor_angle
        self._last_step_time = time.monotonic()

        qpos = torch.tensor(
            [[motor_angle, pendulum_angle]], dtype=torch.float32
        )
        qvel = torch.tensor(
            [[motor_vel, pendulum_vel]], dtype=torch.float32
        )
        return qpos, qvel
||||
    def _sim_reset(
        self, env_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Physically reset the robot: stop, re-centre, wait for stillness.

        Parameters
        ----------
        env_ids : accepted for interface compatibility; there is only one
            physical environment, so the value is not used here.

        Returns
        -------
        (qpos, qvel) : (1, 2) float32 tensors of the settled state.

        Raises
        ------
        RuntimeError : if the ESP32 rebooted or disconnected — the encoder
            zero reference is then lost and manual re-centering is needed.
        """
        # If ESP32 rebooted or disconnected, we can't recover.
        if self._rebooted or self._serial_disconnected:
            raise RuntimeError(
                "ESP32 rebooted or disconnected during training! "
                "Encoder center is lost. "
                "Please re-center the motor manually and restart."
            )

        # Stop motor and restart streaming.
        self._send("M0")
        self._send("H")
        self._streaming = False
        time.sleep(0.05)
        self._state_event.clear()
        self._send("G")
        self._streaming = True
        self._last_data_time = time.monotonic()
        time.sleep(0.05)

        # Physically return the motor to the centre position.
        self._drive_to_center()

        # Wait until the pendulum settles.
        self._wait_for_pendulum_still()

        # Refresh data timer so health checks don't false-positive.
        self._last_data_time = time.monotonic()

        # Read settled state and return as qpos/qvel.
        state = self._read_state_blocking()
        motor_angle = (
            state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
        )
        motor_vel = (
            state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
        )
        pendulum_angle = math.radians(state["pendulum_angle"])
        pendulum_vel = math.radians(state["pendulum_velocity"])

        qpos = torch.tensor(
            [[motor_angle, pendulum_angle]], dtype=torch.float32
        )
        qvel = torch.tensor(
            [[motor_vel, pendulum_vel]], dtype=torch.float32
        )
        return qpos, qvel
||||
    def _sim_close(self) -> None:
        """Stop the reader thread, silence the robot, and release the port."""
        self._reader_running = False
        self._streaming = False
        self._send("H")  # Stop streaming.
        self._send("M0")  # Stop motor.
        time.sleep(0.1)  # Let the commands flush before closing.
        self._reader_thread.join(timeout=1.0)
        self.ser.close()
        # The renderer only exists if render() was ever called.
        if hasattr(self, "_offscreen_renderer") and self._offscreen_renderer is not None:
            self._offscreen_renderer.close()
||||
# ------------------------------------------------------------------
|
||||
# MuJoCo digital-twin rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def _ensure_viz_model(self) -> None:
        """Lazily load the MuJoCo model for visualisation (digital twin).

        Reuses the same URDF + robot.yaml that the MuJoCoRunner would use,
        but only for rendering — no physics stepping.
        """
        # Already loaded — nothing to do.
        if hasattr(self, "_viz_model"):
            return

        # Imported lazily: mujoco is only needed when rendering is requested.
        import mujoco
        from src.runners.mujoco import MuJoCoRunner

        self._viz_model = MuJoCoRunner._load_model(self.env.robot)
        self._viz_data = mujoco.MjData(self._viz_model)
        # Created on the first render() call.
        self._offscreen_renderer = None
||||
    def _sync_viz(self) -> None:
        """Copy current serial sensor state into the MuJoCo viz model."""
        import mujoco

        self._ensure_viz_model()
        state = self._read_state()

        # Set joint positions from serial data (encoder counts → radians,
        # pendulum degrees → radians).
        motor_angle = (
            state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
        )
        pendulum_angle = math.radians(state["pendulum_angle"])
        self._viz_data.qpos[0] = motor_angle
        self._viz_data.qpos[1] = pendulum_angle

        # Set joint velocities (for any velocity-dependent visuals).
        motor_vel = (
            state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
        )
        pendulum_vel = math.radians(state["pendulum_velocity"])
        self._viz_data.qvel[0] = motor_vel
        self._viz_data.qvel[1] = pendulum_vel

        # Forward kinematics (updates body positions for rendering).
        mujoco.mj_forward(self._viz_model, self._viz_data)
||||
def render(self, env_idx: int = 0) -> np.ndarray:
|
||||
"""Offscreen render of the digital-twin MuJoCo model.
|
||||
|
||||
Called by VideoRecordingTrainer during training to capture frames.
|
||||
"""
|
||||
import mujoco
|
||||
|
||||
self._sync_viz()
|
||||
|
||||
if self._offscreen_renderer is None:
|
||||
self._offscreen_renderer = mujoco.Renderer(
|
||||
self._viz_model, width=640, height=480,
|
||||
)
|
||||
self._offscreen_renderer.update_scene(self._viz_data)
|
||||
return self._offscreen_renderer.render().copy()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override step() for runner-level safety
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step the real robot, layering hardware safety on top of BaseRunner.

        Adds three checks around the normal env step: reboot/disconnect
        detection before stepping, connection health after stepping, and
        the hard motor-angle limit from hardware.yaml. Any trip stops the
        motor, terminates the episode with a -100 reward, and tags ``info``.
        """
        # Check for ESP32 reboot / disconnect BEFORE stepping.
        if self._rebooted or self._serial_disconnected:
            self._send("M0")
            # Return a terminal observation with penalty.
            qpos, qvel = self._make_current_state()
            state = self.env.build_state(qpos, qvel)
            obs = self.env.compute_observations(state)
            reward = torch.tensor([[-100.0]])
            terminated = torch.tensor([[True]])
            truncated = torch.tensor([[False]])
            return obs, reward, terminated, truncated, {"reboot_detected": True}

        # Normal step via BaseRunner (calls _sim_step → env logic).
        obs, rewards, terminated, truncated, info = super().step(actions)

        # Check connection health after stepping.
        if not self._check_connection_health():
            self._send("M0")
            terminated = torch.tensor([[True]])
            rewards = torch.tensor([[-100.0]])
            info["reboot_detected"] = True

        # Check motor angle against hard safety limit.
        # Uses the cached value from _sim_step — no extra serial read.
        if self._max_motor_angle_rad > 0:
            motor_angle = abs(getattr(self, "_last_motor_angle_rad", 0.0))
            if motor_angle >= self._max_motor_angle_rad:
                self._send("M0")
                terminated = torch.tensor([[True]])
                rewards = torch.tensor([[-100.0]])
                info["motor_limit_exceeded"] = True

        # Always stop motor on episode end.
        if terminated.any() or truncated.any():
            self._send("M0")

        return obs, rewards, terminated, truncated, info
||||
# ------------------------------------------------------------------
|
||||
# Serial helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send(self, cmd: str) -> None:
|
||||
"""Send a command to the ESP32."""
|
||||
try:
|
||||
self.ser.write(f"{cmd}\n".encode())
|
||||
except (OSError, self._serial_mod.SerialException):
|
||||
self._serial_disconnected = True
|
||||
|
||||
def _serial_reader(self) -> None:
|
||||
"""Background thread: continuously read and parse serial lines."""
|
||||
while self._reader_running:
|
||||
try:
|
||||
if self.ser.in_waiting:
|
||||
line = (
|
||||
self.ser.readline()
|
||||
.decode("utf-8", errors="ignore")
|
||||
.strip()
|
||||
)
|
||||
|
||||
# Detect ESP32 reboot: it prints READY on startup.
|
||||
if line.startswith("READY"):
|
||||
self._rebooted = True
|
||||
logger.critical("ESP32 reboot detected: %s", line)
|
||||
continue
|
||||
|
||||
if line.startswith("S,"):
|
||||
parts = line.split(",")
|
||||
if len(parts) >= 12:
|
||||
esp_ms = int(parts[1])
|
||||
enc = int(parts[2])
|
||||
|
||||
# Detect reboot: timestamp jumped backwards.
|
||||
if (
|
||||
self._last_esp_ms > 5000
|
||||
and esp_ms < self._last_esp_ms - 3000
|
||||
):
|
||||
self._rebooted = True
|
||||
logger.critical(
|
||||
"ESP32 reboot detected: timestamp"
|
||||
" %d -> %d",
|
||||
self._last_esp_ms,
|
||||
esp_ms,
|
||||
)
|
||||
|
||||
# Detect reboot: encoder snapped to 0 from
|
||||
# a far position.
|
||||
if (
|
||||
abs(self._last_encoder_count)
|
||||
> self.config.encoder_jump_threshold
|
||||
and abs(enc) < 5
|
||||
):
|
||||
self._rebooted = True
|
||||
logger.critical(
|
||||
"ESP32 reboot detected: encoder"
|
||||
" jumped %d -> %d",
|
||||
self._last_encoder_count,
|
||||
enc,
|
||||
)
|
||||
|
||||
self._last_esp_ms = esp_ms
|
||||
self._last_encoder_count = enc
|
||||
self._last_data_time = time.monotonic()
|
||||
|
||||
parsed: dict[str, Any] = {
|
||||
"timestamp_ms": esp_ms,
|
||||
"encoder_count": enc,
|
||||
"rpm": float(parts[3]),
|
||||
"motor_speed": int(parts[4]),
|
||||
"at_limit": bool(int(parts[5])),
|
||||
"pendulum_angle": float(parts[6]),
|
||||
"pendulum_velocity": float(parts[7]),
|
||||
"target_speed": int(parts[8]),
|
||||
"braking": bool(int(parts[9])),
|
||||
"enc_vel_cps": float(parts[10]),
|
||||
"pendulum_ok": bool(int(parts[11])),
|
||||
}
|
||||
with self._state_lock:
|
||||
self._latest_state = parsed
|
||||
self._state_event.set()
|
||||
else:
|
||||
time.sleep(0.001) # Avoid busy-spinning.
|
||||
except (OSError, self._serial_mod.SerialException) as exc:
|
||||
self._serial_disconnected = True
|
||||
logger.critical("Serial connection lost: %s", exc)
|
||||
break
|
||||
|
||||
def _check_connection_health(self) -> bool:
|
||||
"""Return True if the ESP32 connection appears healthy."""
|
||||
if self._serial_disconnected:
|
||||
logger.critical("ESP32 serial connection lost.")
|
||||
return False
|
||||
if (
|
||||
self._streaming
|
||||
and (time.monotonic() - self._last_data_time)
|
||||
> self.config.no_data_timeout
|
||||
):
|
||||
logger.critical(
|
||||
"No data from ESP32 for %.1f s — possible crash/disconnect.",
|
||||
time.monotonic() - self._last_data_time,
|
||||
)
|
||||
self._rebooted = True
|
||||
return False
|
||||
return True
|
||||
|
||||
def _read_state(self) -> dict[str, Any]:
|
||||
"""Return the most recent state from the reader thread (non-blocking).
|
||||
|
||||
The background thread updates at ~50 Hz and `_sim_step` already
|
||||
sleeps for `dt` before calling this, so the data is always fresh.
|
||||
"""
|
||||
with self._state_lock:
|
||||
return dict(self._latest_state)
|
||||
|
||||
def _read_state_blocking(self, timeout: float = 0.05) -> dict[str, Any]:
|
||||
"""Wait for a fresh sample, then return it.
|
||||
|
||||
Used during reset / settling where we need to guarantee we have
|
||||
a new reading (no prior dt sleep).
|
||||
"""
|
||||
self._state_event.clear()
|
||||
self._state_event.wait(timeout=timeout)
|
||||
with self._state_lock:
|
||||
return dict(self._latest_state)
|
||||
|
||||
def _make_current_state(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Build qpos/qvel from current sensor data (utility)."""
|
||||
state = self._read_state_blocking()
|
||||
motor_angle = (
|
||||
state["encoder_count"] / self._counts_per_rev * 2.0 * math.pi
|
||||
)
|
||||
motor_vel = (
|
||||
state["enc_vel_cps"] / self._counts_per_rev * 2.0 * math.pi
|
||||
)
|
||||
pendulum_angle = math.radians(state["pendulum_angle"])
|
||||
pendulum_vel = math.radians(state["pendulum_velocity"])
|
||||
|
||||
qpos = torch.tensor(
|
||||
[[motor_angle, pendulum_angle]], dtype=torch.float32
|
||||
)
|
||||
qvel = torch.tensor(
|
||||
[[motor_vel, pendulum_vel]], dtype=torch.float32
|
||||
)
|
||||
return qpos, qvel
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Physical reset helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _drive_to_center(self) -> None:
|
||||
"""Drive the motor back toward encoder=0 using bang-bang control."""
|
||||
rc = self._hw.reset
|
||||
start = time.time()
|
||||
while time.time() - start < rc.drive_timeout:
|
||||
state = self._read_state_blocking()
|
||||
enc = state["encoder_count"]
|
||||
if abs(enc) < rc.deadband:
|
||||
break
|
||||
speed = rc.drive_speed if enc < 0 else -rc.drive_speed
|
||||
self._send(f"M{speed}")
|
||||
time.sleep(0.05)
|
||||
self._send("M0")
|
||||
time.sleep(0.2)
|
||||
|
||||
def _wait_for_pendulum_still(self) -> None:
|
||||
"""Block until the pendulum has settled (angle and velocity near zero)."""
|
||||
rc = self._hw.reset
|
||||
stable_since: float | None = None
|
||||
start = time.monotonic()
|
||||
|
||||
while time.monotonic() - start < rc.settle_timeout:
|
||||
state = self._read_state_blocking()
|
||||
angle_ok = abs(state["pendulum_angle"]) < rc.settle_angle_deg
|
||||
vel_ok = abs(state["pendulum_velocity"]) < rc.settle_vel_dps
|
||||
|
||||
if angle_ok and vel_ok:
|
||||
if stable_since is None:
|
||||
stable_since = time.monotonic()
|
||||
elif time.monotonic() - stable_since >= rc.settle_duration:
|
||||
logger.info(
|
||||
"Pendulum settled after %.2f s",
|
||||
time.monotonic() - start,
|
||||
)
|
||||
return
|
||||
else:
|
||||
stable_since = None
|
||||
time.sleep(0.02)
|
||||
|
||||
logger.warning(
|
||||
"Pendulum did not fully settle within %.1f s — proceeding anyway.",
|
||||
rc.settle_timeout,
|
||||
)
|
||||
1
src/sysid/__init__.py
Normal file
1
src/sysid/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""System identification — tune simulation parameters to match real hardware."""
|
||||
381
src/sysid/capture.py
Normal file
381
src/sysid/capture.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Capture a real-robot trajectory under random excitation (PRBS-style).
|
||||
|
||||
Connects to the ESP32 over serial, sends random PWM commands to excite
|
||||
the system, and records motor + pendulum angles and velocities at ~50 Hz.
|
||||
|
||||
Saves a compressed numpy archive (.npz) that the optimizer can replay
|
||||
in simulation to fit physics parameters.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.capture \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--port /dev/cu.usbserial-0001 \
|
||||
--duration 20
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
# ── Serial protocol helpers (mirrored from SerialRunner) ─────────────
|
||||
|
||||
|
||||
def _parse_state_line(line: str) -> dict[str, Any] | None:
|
||||
"""Parse an ``S,…`` state line from the ESP32."""
|
||||
if not line.startswith("S,"):
|
||||
return None
|
||||
parts = line.split(",")
|
||||
if len(parts) < 12:
|
||||
return None
|
||||
try:
|
||||
return {
|
||||
"timestamp_ms": int(parts[1]),
|
||||
"encoder_count": int(parts[2]),
|
||||
"rpm": float(parts[3]),
|
||||
"motor_speed": int(parts[4]),
|
||||
"at_limit": bool(int(parts[5])),
|
||||
"pendulum_angle": float(parts[6]),
|
||||
"pendulum_velocity": float(parts[7]),
|
||||
"target_speed": int(parts[8]),
|
||||
"braking": bool(int(parts[9])),
|
||||
"enc_vel_cps": float(parts[10]),
|
||||
"pendulum_ok": bool(int(parts[11])),
|
||||
}
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
# ── Background serial reader ─────────────────────────────────────────
|
||||
|
||||
|
||||
class _SerialReader:
    """Minimal background reader for the ESP32 serial stream.

    Opens the port, starts a daemon thread that keeps the latest parsed
    ``S,…`` state line available, and offers blocking and non-blocking
    reads plus a simple command sender.
    """

    def __init__(self, port: str, baud: int = 115200):
        # Imported lazily so sim-only installs don't need pyserial.
        import serial as _serial

        self._serial_mod = _serial
        self.ser = _serial.Serial(port, baud, timeout=0.05)
        time.sleep(2)  # Wait for ESP32 boot.
        self.ser.reset_input_buffer()

        # Latest parsed state, guarded by _lock; _event pulses per sample.
        self._latest: dict[str, Any] = {}
        self._lock = threading.Lock()
        self._event = threading.Event()
        self._running = True

        self._thread = threading.Thread(target=self._reader_loop, daemon=True)
        self._thread.start()

    def _reader_loop(self) -> None:
        """Daemon thread body: read lines and publish parsed states."""
        while self._running:
            try:
                if self.ser.in_waiting:
                    line = (
                        self.ser.readline()
                        .decode("utf-8", errors="ignore")
                        .strip()
                    )
                    parsed = _parse_state_line(line)
                    if parsed is not None:
                        with self._lock:
                            self._latest = parsed
                        self._event.set()
                else:
                    time.sleep(0.001)  # Avoid busy-spinning.
            except (OSError, self._serial_mod.SerialException):
                log.critical("serial_lost")
                break

    def send(self, cmd: str) -> None:
        """Send one newline-terminated command to the ESP32 (best effort)."""
        try:
            self.ser.write(f"{cmd}\n".encode())
        except (OSError, self._serial_mod.SerialException):
            log.critical("serial_send_failed", cmd=cmd)

    def read(self) -> dict[str, Any]:
        """Return the most recent parsed state (non-blocking)."""
        with self._lock:
            return dict(self._latest)

    def read_blocking(self, timeout: float = 0.1) -> dict[str, Any]:
        """Wait (up to ``timeout``) for a fresh sample, then return it."""
        self._event.clear()
        self._event.wait(timeout=timeout)
        return self.read()

    def close(self) -> None:
        """Stop the reader thread, silence the robot, and close the port."""
        self._running = False
        self.send("H")
        self.send("M0")
        time.sleep(0.1)
        self._thread.join(timeout=1.0)
        self.ser.close()
|
||||
# ── PRBS excitation signal ───────────────────────────────────────────
|
||||
|
||||
|
||||
class _PRBSExcitation:
|
||||
"""Random hold-value excitation with configurable amplitude and hold time.
|
||||
|
||||
At each call to ``__call__``, returns the current PWM value.
|
||||
The value is held for a random duration (``hold_min``–``hold_max`` ms),
|
||||
then a new random value is drawn uniformly from ``[-amplitude, +amplitude]``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
amplitude: int = 180,
|
||||
hold_min_ms: int = 50,
|
||||
hold_max_ms: int = 300,
|
||||
):
|
||||
self.amplitude = amplitude
|
||||
self.hold_min_ms = hold_min_ms
|
||||
self.hold_max_ms = hold_max_ms
|
||||
self._current: int = 0
|
||||
self._switch_time: float = 0.0
|
||||
self._new_value()
|
||||
|
||||
def _new_value(self) -> None:
|
||||
self._current = random.randint(-self.amplitude, self.amplitude)
|
||||
hold_ms = random.randint(self.hold_min_ms, self.hold_max_ms)
|
||||
self._switch_time = time.monotonic() + hold_ms / 1000.0
|
||||
|
||||
def __call__(self) -> int:
|
||||
if time.monotonic() >= self._switch_time:
|
||||
self._new_value()
|
||||
return self._current
|
||||
|
||||
|
||||
# ── Main capture loop ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def capture(
    robot_path: str | Path,
    port: str = "/dev/cu.usbserial-0001",
    baud: int = 115200,
    duration: float = 20.0,
    amplitude: int = 180,
    hold_min_ms: int = 50,
    hold_max_ms: int = 300,
    dt: float = 0.02,
) -> Path:
    """Run the capture procedure and return the path to the saved .npz file.

    Excites the robot with random held PWM values, records the state at
    roughly ``1/dt`` Hz, and writes a compressed archive under
    ``<robot_path>/recordings/``.

    Parameters
    ----------
    robot_path : path to robot asset directory (contains hardware.yaml)
    port : serial port for ESP32
    baud : baud rate
    duration : capture duration in seconds
    amplitude : max PWM magnitude for excitation (0–255)
    hold_min_ms / hold_max_ms : random hold time range (ms)
    dt : target sampling period (seconds), default 50 Hz

    Raises
    ------
    FileNotFoundError : if hardware.yaml is missing from ``robot_path``.
    """
    robot_path = Path(robot_path).resolve()

    # Load hardware config for encoder conversion + safety.
    hw_yaml = robot_path / "hardware.yaml"
    if not hw_yaml.exists():
        raise FileNotFoundError(f"hardware.yaml not found in {robot_path}")
    raw_hw = yaml.safe_load(hw_yaml.read_text())
    ppr = raw_hw.get("encoder", {}).get("ppr", 11)
    gear_ratio = raw_hw.get("encoder", {}).get("gear_ratio", 30.0)
    # Quadrature decoding → 4 counts per encoder pulse.
    counts_per_rev: float = ppr * gear_ratio * 4.0
    max_motor_deg = raw_hw.get("safety", {}).get("max_motor_angle_deg", 90.0)
    max_motor_rad = math.radians(max_motor_deg) if max_motor_deg > 0 else 0.0

    # Connect.
    reader = _SerialReader(port, baud)
    excitation = _PRBSExcitation(amplitude, hold_min_ms, hold_max_ms)

    # Prepare recording buffers.
    max_samples = int(duration / dt) + 500  # headroom
    rec_time = np.zeros(max_samples, dtype=np.float64)
    rec_action = np.zeros(max_samples, dtype=np.float64)
    rec_motor_angle = np.zeros(max_samples, dtype=np.float64)
    rec_motor_vel = np.zeros(max_samples, dtype=np.float64)
    rec_pend_angle = np.zeros(max_samples, dtype=np.float64)
    rec_pend_vel = np.zeros(max_samples, dtype=np.float64)

    # Start streaming.
    reader.send("G")
    time.sleep(0.1)

    log.info(
        "capture_starting",
        port=port,
        duration=duration,
        amplitude=amplitude,
        hold_range_ms=f"{hold_min_ms}–{hold_max_ms}",
        dt=dt,
    )

    idx = 0
    t0 = time.monotonic()
    try:
        while True:
            loop_start = time.monotonic()
            elapsed = loop_start - t0
            if elapsed >= duration:
                break

            # Get excitation PWM.
            pwm = excitation()

            # Safety: reverse/zero if near motor limit.
            state = reader.read()
            if state:
                motor_angle_rad = (
                    state.get("encoder_count", 0) / counts_per_rev * 2.0 * math.pi
                )
                if max_motor_rad > 0:
                    margin = max_motor_rad * 0.85  # start braking at 85%
                    if motor_angle_rad > margin and pwm > 0:
                        pwm = -abs(pwm)  # reverse
                    elif motor_angle_rad < -margin and pwm < 0:
                        pwm = abs(pwm)  # reverse

            # Send command.
            reader.send(f"M{pwm}")

            # Wait for fresh data.
            time.sleep(max(0, dt - (time.monotonic() - loop_start) - 0.005))
            state = reader.read_blocking(timeout=dt)

            if state:
                enc = state.get("encoder_count", 0)
                motor_angle = enc / counts_per_rev * 2.0 * math.pi
                motor_vel = (
                    state.get("enc_vel_cps", 0.0) / counts_per_rev * 2.0 * math.pi
                )
                pend_angle = math.radians(state.get("pendulum_angle", 0.0))
                pend_vel = math.radians(state.get("pendulum_velocity", 0.0))

                # Normalised action: PWM / 255 → [-1, 1]
                action_norm = pwm / 255.0

                if idx < max_samples:
                    rec_time[idx] = elapsed
                    rec_action[idx] = action_norm
                    rec_motor_angle[idx] = motor_angle
                    rec_motor_vel[idx] = motor_vel
                    rec_pend_angle[idx] = pend_angle
                    rec_pend_vel[idx] = pend_vel
                    idx += 1

            # Progress.
            if idx % 50 == 0:
                log.info(
                    "capture_progress",
                    elapsed=f"{elapsed:.1f}/{duration:.0f}s",
                    samples=idx,
                    pwm=pwm,
                )

            # Pace to dt.
            remaining = dt - (time.monotonic() - loop_start)
            if remaining > 0:
                time.sleep(remaining)

    finally:
        # Always silence the motor and release the port, even on Ctrl-C.
        reader.send("M0")
        reader.close()

    # Trim to actual sample count.
    rec_time = rec_time[:idx]
    rec_action = rec_action[:idx]
    rec_motor_angle = rec_motor_angle[:idx]
    rec_motor_vel = rec_motor_vel[:idx]
    rec_pend_angle = rec_pend_angle[:idx]
    rec_pend_vel = rec_pend_vel[:idx]

    # Save.
    recordings_dir = robot_path / "recordings"
    recordings_dir.mkdir(exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = recordings_dir / f"capture_{stamp}.npz"
    np.savez_compressed(
        out_path,
        time=rec_time,
        action=rec_action,
        motor_angle=rec_motor_angle,
        motor_vel=rec_motor_vel,
        pendulum_angle=rec_pend_angle,
        pendulum_vel=rec_pend_vel,
    )

    log.info(
        "capture_saved",
        path=str(out_path),
        samples=idx,
        duration_actual=f"{rec_time[-1]:.2f}s" if idx > 0 else "0s",
    )
    return out_path
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point — parse flags and run one capture session."""
    parser = argparse.ArgumentParser(
        description="Capture a real-robot trajectory for system identification."
    )
    # (flag, add_argument kwargs). Dest names (dashes -> underscores)
    # line up 1:1 with capture()'s keyword parameters.
    cli_specs = [
        ("--robot-path", dict(
            type=str,
            default="assets/rotary_cartpole",
            help="Path to robot asset directory (contains hardware.yaml)",
        )),
        ("--port", dict(
            type=str,
            default="/dev/cu.usbserial-0001",
            help="Serial port for ESP32",
        )),
        ("--baud", dict(type=int, default=115200)),
        ("--duration", dict(type=float, default=20.0, help="Capture duration (s)")),
        ("--amplitude", dict(type=int, default=180, help="Max PWM magnitude (0–255)")),
        ("--hold-min-ms", dict(type=int, default=50, help="Min hold time (ms)")),
        ("--hold-max-ms", dict(type=int, default=300, help="Max hold time (ms)")),
        ("--dt", dict(type=float, default=0.02, help="Sample period (s)")),
    ]
    for flag, kwargs in cli_specs:
        parser.add_argument(flag, **kwargs)

    ns = parser.parse_args()
    capture(**vars(ns))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
186
src/sysid/export.py
Normal file
186
src/sysid/export.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Export tuned parameters to URDF and robot.yaml files.
|
||||
|
||||
Reads the original files, injects the optimised parameter values,
|
||||
and writes ``rotary_cartpole_tuned.urdf`` + ``robot_tuned.yaml``
|
||||
alongside the originals in the robot asset directory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import shutil
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
import yaml
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def export_tuned_files(
    robot_path: str | Path,
    params: dict[str, float],
) -> tuple[Path, Path]:
    """Write tuned URDF and robot.yaml files.

    Reads the original ``robot.yaml`` and the URDF it points to, injects
    the optimised values from *params* (missing keys leave the original
    value untouched), and writes ``<urdf stem>_tuned.urdf`` plus
    ``robot_tuned.yaml`` next to the originals.  The originals are never
    modified.

    Parameters
    ----------
    robot_path : robot asset directory (contains robot.yaml + *.urdf)
    params : dict of parameter name → tuned value (from optimizer)

    Returns
    -------
    (tuned_urdf_path, tuned_robot_yaml_path)
    """
    robot_path = Path(robot_path).resolve()

    # ── Load originals ───────────────────────────────────────────
    robot_yaml_path = robot_path / "robot.yaml"
    robot_cfg = yaml.safe_load(robot_yaml_path.read_text())
    # Raises KeyError if robot.yaml has no "urdf" entry.
    urdf_path = robot_path / robot_cfg["urdf"]

    # ── Tune URDF ────────────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    # Only the "arm" and "pendulum" links carry tunable inertials;
    # links without an <inertial> child are skipped outright.
    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write tuned URDF.
    tuned_urdf_name = urdf_path.stem + "_tuned" + urdf_path.suffix
    tuned_urdf_path = robot_path / tuned_urdf_name

    # Preserve the XML declaration and original formatting as much as possible.
    ET.indent(tree, space=" ")
    tree.write(str(tuned_urdf_path), xml_declaration=True, encoding="unicode")
    log.info("tuned_urdf_written", path=str(tuned_urdf_path))

    # ── Tune robot.yaml ──────────────────────────────────────────
    # Deep copy so nested dicts in the original config are never mutated.
    tuned_cfg = copy.deepcopy(robot_cfg)

    # Point to the tuned URDF.
    tuned_cfg["urdf"] = tuned_urdf_name

    # Update actuator parameters (first actuator only); values are
    # rounded to 6 d.p. to keep the emitted YAML compact.
    if tuned_cfg.get("actuators") and len(tuned_cfg["actuators"]) > 0:
        act = tuned_cfg["actuators"][0]
        if "actuator_gear" in params:
            act["gear"] = round(params["actuator_gear"], 6)
        if "actuator_filter_tau" in params:
            act["filter_tau"] = round(params["actuator_filter_tau"], 6)
        if "motor_damping" in params:
            act["damping"] = round(params["motor_damping"], 6)

    # Update joint overrides, creating the nested dicts on demand.
    if "joints" not in tuned_cfg:
        tuned_cfg["joints"] = {}

    if "motor_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["motor_joint"] = {}
    mj = tuned_cfg["joints"]["motor_joint"]
    if "motor_armature" in params:
        mj["armature"] = round(params["motor_armature"], 6)
    if "motor_frictionloss" in params:
        mj["frictionloss"] = round(params["motor_frictionloss"], 6)

    if "pendulum_joint" not in tuned_cfg["joints"]:
        tuned_cfg["joints"]["pendulum_joint"] = {}
    pj = tuned_cfg["joints"]["pendulum_joint"]
    if "pendulum_damping" in params:
        pj["damping"] = round(params["pendulum_damping"], 6)

    # Write tuned robot.yaml.
    tuned_yaml_path = robot_path / "robot_tuned.yaml"

    # Add a header comment.
    header = (
        "# Tuned robot config — generated by src.sysid.optimize\n"
        "# Original: robot.yaml\n"
        "# Run `python -m src.sysid.visualize` to compare real vs sim.\n\n"
    )
    tuned_yaml_path.write_text(
        header + yaml.dump(tuned_cfg, default_flow_style=False, sort_keys=False)
    )
    log.info("tuned_robot_yaml_written", path=str(tuned_yaml_path))

    return tuned_urdf_path, tuned_yaml_path
|
||||
|
||||
|
||||
# ── XML helpers (shared with rollout.py) ─────────────────────────────
|
||||
|
||||
|
||||
def _set_mass(inertial: ET.Element, mass: float | None) -> None:
    """Overwrite the ``value`` attribute of the <mass> child, if any.

    A *mass* of None or a missing <mass> element makes this a no-op.
    """
    target = inertial.find("mass")
    if mass is not None and target is not None:
        target.set("value", str(mass))
|
||||
|
||||
|
||||
def _set_com(
    inertial: ET.Element,
    x: float | None,
    y: float | None,
    z: float | None,
) -> None:
    """Patch individual components of <origin xyz="..."> inside *inertial*.

    Components given as None keep their current value; a missing
    <origin> element makes the whole call a no-op.
    """
    origin = inertial.find("origin")
    if origin is None:
        return
    parts = origin.get("xyz", "0 0 0").split()
    for axis, value in enumerate((x, y, z)):
        if value is not None:
            parts[axis] = str(value)
    origin.set("xyz", " ".join(parts))
|
||||
|
||||
|
||||
def _set_inertia(
    inertial: ET.Element,
    ixx: float | None = None,
    iyy: float | None = None,
    izz: float | None = None,
    ixy: float | None = None,
    iyz: float | None = None,
    ixz: float | None = None,
) -> None:
    """Overwrite the given components of the <inertia> element.

    Only components that are not None are written; with no <inertia>
    child the call is a no-op.
    """
    target = inertial.find("inertia")
    if target is None:
        return
    components = {
        "ixx": ixx,
        "iyy": iyy,
        "izz": izz,
        "ixy": ixy,
        "iyz": iyz,
        "ixz": ixz,
    }
    for attr, val in components.items():
        if val is not None:
            target.set(attr, str(val))
|
||||
376
src/sysid/optimize.py
Normal file
376
src/sysid/optimize.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""CMA-ES optimiser — fit simulation parameters to a real-robot recording.
|
||||
|
||||
Minimises the trajectory-matching cost between a MuJoCo rollout and a
|
||||
recorded real-robot sequence. Uses the ``cmaes`` package (pure-Python
|
||||
CMA-ES with native box-constraint support).
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.optimize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||
|
||||
# Shorter run for testing:
|
||||
python -m src.sysid.optimize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording <file>.npz \
|
||||
--max-generations 10 --population-size 8
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
from src.sysid.rollout import (
|
||||
ROTARY_CARTPOLE_PARAMS,
|
||||
ParamSpec,
|
||||
bounds_arrays,
|
||||
defaults_vector,
|
||||
params_to_dict,
|
||||
rollout,
|
||||
windowed_rollout,
|
||||
)
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
# ── Cost function ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _angle_diff(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Element-wise signed angular difference a − b, wrapped into [−π, π]."""
    delta = a - b
    return np.arctan2(np.sin(delta), np.cos(delta))
|
||||
|
||||
|
||||
def _check_inertia_valid(params: dict[str, float]) -> bool:
    """Cheap positive-definiteness screen for the pendulum inertia tensor.

    Checks the three diagonal entries and the leading 2×2 minor — ixy is
    the only off-diagonal term being tuned.  Missing keys fall back to the
    same defaults the ParamSpec table uses.
    """
    get = params.get
    ixx = get("pendulum_ixx", 6.16e-06)
    iyy = get("pendulum_iyy", 6.16e-06)
    izz = get("pendulum_izz", 1.23e-05)
    ixy = get("pendulum_ixy", 6.10e-06)
    if min(ixx, iyy, izz) <= 0:
        return False
    return ixx * iyy > ixy * ixy
|
||||
|
||||
|
||||
def _compute_trajectory_cost(
    sim: dict[str, np.ndarray],
    recording: dict[str, np.ndarray],
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
) -> float:
    """Weighted mean-squared trajectory error between sim and recording.

    Angle channels are compared through the wrap-aware difference,
    velocity channels through a plain difference; each channel
    contributes weight * mean(err**2) to the total.
    """
    total = 0.0
    for key, weight, wrap in (
        ("motor_angle", pos_weight, True),
        ("pendulum_angle", pos_weight, True),
        ("motor_vel", vel_weight, False),
        ("pendulum_vel", vel_weight, False),
    ):
        if wrap:
            err = _angle_diff(sim[key], recording[key])
        else:
            err = sim[key] - recording[key]
        total += weight * float(np.mean(np.square(err)))
    return total
|
||||
|
||||
|
||||
def cost_function(
    params_vec: np.ndarray,
    recording: dict[str, np.ndarray],
    robot_path: Path,
    specs: list[ParamSpec],
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
) -> float:
    """Trajectory-matching cost for one candidate parameter vector.

    By default the replay uses **multiple shooting** (``windowed_rollout``):
    the recording is split into ``window_duration``-second windows, each
    re-initialised from the real state, so early errors don't compound and
    the cost landscape stays smooth for CMA-ES.  Pass ``window_duration=0``
    for a single open-loop replay via ``rollout`` (not recommended).

    Returns a 1e6 penalty when the candidate inertia tensor fails the
    positive-definiteness screen or the rollout raises.
    """
    candidate = params_to_dict(params_vec, specs)

    # Cheap physical-plausibility gate before paying for a full rollout.
    if not _check_inertia_valid(candidate):
        return 1e6

    shared = {
        "robot_path": robot_path,
        "params": candidate,
        "sim_dt": sim_dt,
        "substeps": substeps,
    }
    try:
        if window_duration > 0:
            sim = windowed_rollout(
                recording=recording,
                window_duration=window_duration,
                **shared,
            )
        else:
            sim = rollout(
                actions=recording["action"],
                timesteps=recording["time"],
                **shared,
            )
    except Exception as exc:
        # A bad parameter set can make MuJoCo blow up; treat it as a
        # very expensive candidate rather than killing the optimiser.
        log.warning("rollout_failed", error=str(exc))
        return 1e6

    return _compute_trajectory_cost(sim, recording, pos_weight, vel_weight)
|
||||
|
||||
|
||||
# ── CMA-ES optimisation loop ────────────────────────────────────────
|
||||
|
||||
|
||||
def optimize(
    robot_path: str | Path,
    recording_path: str | Path,
    specs: list[ParamSpec] | None = None,
    sigma0: float = 0.3,
    population_size: int = 20,
    max_generations: int = 1000,
    sim_dt: float = 0.002,
    substeps: int = 10,
    pos_weight: float = 1.0,
    vel_weight: float = 0.1,
    window_duration: float = 0.5,
    seed: int = 42,
) -> dict:
    """Run CMA-ES optimisation and return results.

    The search runs in coordinates normalised to [0, 1] per parameter
    (better conditioned than the raw natural units, whose scales span
    many orders of magnitude); candidates are mapped back to natural
    units before each cost evaluation.

    Returns a dict with:
        best_params: dict[str, float]
        best_cost: float
        history: list of (generation, best_cost) tuples
        recording: str (path used)
        specs: list of param names
    """
    # Imported lazily so this module can be imported without the
    # optional ``cmaes`` dependency installed.
    from cmaes import CMA

    robot_path = Path(robot_path).resolve()
    recording_path = Path(recording_path).resolve()

    if specs is None:
        specs = ROTARY_CARTPOLE_PARAMS

    # Load recording.
    recording = dict(np.load(recording_path))
    n_samples = len(recording["time"])
    duration = recording["time"][-1] - recording["time"][0]
    n_windows = max(1, int(duration / window_duration)) if window_duration > 0 else 1
    log.info(
        "recording_loaded",
        path=str(recording_path),
        samples=n_samples,
        duration=f"{duration:.1f}s",
        window_duration=f"{window_duration}s",
        n_windows=n_windows,
    )

    # Initial point (defaults) — normalised to [0, 1] for CMA-ES.
    lo, hi = bounds_arrays(specs)
    x0 = defaults_vector(specs)

    # Normalise to [0, 1] for the optimizer (better conditioned).
    span = hi - lo
    span[span == 0] = 1.0  # avoid division by zero

    def to_normed(x: np.ndarray) -> np.ndarray:
        # Natural units -> [0, 1] search coordinates.
        return (x - lo) / span

    def from_normed(x_n: np.ndarray) -> np.ndarray:
        # [0, 1] search coordinates -> natural units.
        return x_n * span + lo

    x0_normed = to_normed(x0)
    bounds_normed = np.column_stack(
        [np.zeros(len(specs)), np.ones(len(specs))]
    )

    optimizer = CMA(
        mean=x0_normed,
        sigma=sigma0,
        bounds=bounds_normed,
        population_size=population_size,
        seed=seed,
    )

    best_cost = float("inf")
    best_params_vec = x0.copy()
    history: list[tuple[int, float]] = []

    log.info(
        "cmaes_starting",
        n_params=len(specs),
        population=population_size,
        max_gens=max_generations,
        sigma0=sigma0,
    )

    t0 = time.monotonic()

    for gen in range(max_generations):
        solutions = []
        for _ in range(optimizer.population_size):
            x_normed = optimizer.ask()
            x_natural = from_normed(x_normed)

            # Clip to bounds (CMA-ES can slightly exceed with sampling noise).
            # NOTE(review): the *unclipped* x_normed is what is reported back
            # to tell() below, so when clipping triggers, the cost fed to CMA
            # corresponds to a slightly different (clipped) point — confirm
            # this is intended.
            x_natural = np.clip(x_natural, lo, hi)

            c = cost_function(
                x_natural,
                recording,
                robot_path,
                specs,
                sim_dt=sim_dt,
                substeps=substeps,
                pos_weight=pos_weight,
                vel_weight=vel_weight,
                window_duration=window_duration,
            )
            solutions.append((x_normed, c))

            # Track the global best in natural units.
            if c < best_cost:
                best_cost = c
                best_params_vec = x_natural.copy()

        optimizer.tell(solutions)
        # History records the running best, not the per-generation best.
        history.append((gen, best_cost))

        elapsed = time.monotonic() - t0
        if gen % 5 == 0 or gen == max_generations - 1:
            log.info(
                "cmaes_generation",
                gen=gen,
                best_cost=f"{best_cost:.6f}",
                elapsed=f"{elapsed:.1f}s",
                gen_best=f"{min(c for _, c in solutions):.6f}",
            )

    total_time = time.monotonic() - t0
    best_params = params_to_dict(best_params_vec, specs)

    log.info(
        "cmaes_finished",
        best_cost=f"{best_cost:.6f}",
        total_time=f"{total_time:.1f}s",
        evaluations=max_generations * population_size,
    )

    # Log parameter comparison.
    defaults = params_to_dict(defaults_vector(specs), specs)
    for name in best_params:
        d = defaults[name]
        b = best_params[name]
        # Guard against division by ~zero defaults (e.g. a default of 0.0).
        change_pct = ((b - d) / abs(d) * 100) if abs(d) > 1e-12 else 0.0
        log.info(
            "param_result",
            name=name,
            default=f"{d:.6g}",
            tuned=f"{b:.6g}",
            change=f"{change_pct:+.1f}%",
        )

    return {
        "best_params": best_params,
        "best_cost": best_cost,
        "history": history,
        "recording": str(recording_path),
        "param_names": [s.name for s in specs],
        "defaults": {s.name: s.default for s in specs},
        "timestamp": datetime.now().isoformat(),
    }
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point — run CMA-ES system identification end to end."""
    parser = argparse.ArgumentParser(
        description="Fit simulation parameters to a real-robot recording (CMA-ES)."
    )
    parser.add_argument(
        "--robot-path",
        type=str,
        default="assets/rotary_cartpole",
        help="Path to robot asset directory",
    )
    parser.add_argument(
        "--recording",
        type=str,
        required=True,
        help="Path to .npz recording file",
    )
    # Plain numeric knobs, declared as a table. Defaults match optimize()'s
    # signature, except --max-generations which is deliberately lower here.
    for flag, flag_type, default in (
        ("--sigma0", float, 0.3),
        ("--population-size", int, 20),
        ("--max-generations", int, 200),
        ("--sim-dt", float, 0.002),
        ("--substeps", int, 10),
        ("--pos-weight", float, 1.0),
        ("--vel-weight", float, 0.1),
    ):
        parser.add_argument(flag, type=flag_type, default=default)
    parser.add_argument(
        "--window-duration",
        type=float,
        default=0.5,
        help="Shooting window length in seconds (0 = open-loop, default 0.5)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no-export",
        action="store_true",
        help="Skip exporting tuned files (results JSON only)",
    )
    opts = parser.parse_args()

    result = optimize(
        robot_path=opts.robot_path,
        recording_path=opts.recording,
        sigma0=opts.sigma0,
        population_size=opts.population_size,
        max_generations=opts.max_generations,
        sim_dt=opts.sim_dt,
        substeps=opts.substeps,
        pos_weight=opts.pos_weight,
        vel_weight=opts.vel_weight,
        window_duration=opts.window_duration,
        seed=opts.seed,
    )

    # Persist a JSON summary next to the robot assets. The full
    # per-generation history is collapsed to first/final cost so the
    # file stays small.
    asset_dir = Path(opts.robot_path).resolve()
    summary_path = asset_dir / "sysid_result.json"
    payload = {k: v for k, v in result.items() if k != "history"}
    hist = result["history"]
    payload["history_summary"] = {
        "first_cost": hist[0][1] if hist else None,
        "final_cost": hist[-1][1] if hist else None,
        "generations": len(hist),
    }
    # default=str handles any stray numpy scalars during serialisation.
    summary_path.write_text(json.dumps(payload, indent=2, default=str))
    log.info("results_saved", path=str(summary_path))

    # Export tuned URDF/robot.yaml unless explicitly disabled.
    if not opts.no_export:
        from src.sysid.export import export_tuned_files

        export_tuned_files(
            robot_path=opts.robot_path,
            params=result["best_params"],
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
477
src/sysid/rollout.py
Normal file
477
src/sysid/rollout.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""Deterministic simulation replay — roll out recorded actions in MuJoCo.
|
||||
|
||||
Given a parameter vector and a recorded action sequence, builds a MuJoCo
|
||||
model with overridden physics parameters, replays the actions, and returns
|
||||
the simulated trajectory for comparison with the real recording.
|
||||
|
||||
This module is the inner loop of the CMA-ES optimizer: it is called once
|
||||
per candidate parameter vector per generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import math
|
||||
import tempfile
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
|
||||
# ── Tunable parameter specification ──────────────────────────────────
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ParamSpec:
    """Specification for a single tunable parameter.

    One instance describes one scalar degree of freedom exposed to the
    optimiser: its identifier, its starting value, and the box bounds
    the search is constrained to.
    """

    name: str  # key used in parameter dicts (params_to_dict)
    default: float  # starting value in natural units (defaults_vector)
    lower: float  # lower box bound, natural units (bounds_arrays)
    upper: float  # upper box bound, natural units (bounds_arrays)
    log_scale: bool = False  # optimise in log-space (masses, inertias)
|
||||
|
||||
|
||||
# Default parameter specs for the rotary cartpole.
# Order matters: the optimizer maps a flat vector to these specs.
# Positional ParamSpec arguments are (name, default, lower, upper).
# NOTE(review): log_scale is not read by params_to_dict / defaults_vector /
# bounds_arrays in this module — confirm the optimiser applies it elsewhere.
ROTARY_CARTPOLE_PARAMS: list[ParamSpec] = [
    # ── Arm link (URDF) ──────────────────────────────────────────
    ParamSpec("arm_mass", 0.010, 0.003, 0.05, log_scale=True),
    ParamSpec("arm_com_x", 0.00005, -0.02, 0.02),
    ParamSpec("arm_com_y", 0.0065, -0.01, 0.02),
    ParamSpec("arm_com_z", 0.00563, -0.01, 0.02),
    # ── Pendulum link (URDF) ─────────────────────────────────────
    ParamSpec("pendulum_mass", 0.015, 0.005, 0.05, log_scale=True),
    ParamSpec("pendulum_com_x", 0.1583, 0.05, 0.25),
    ParamSpec("pendulum_com_y", -0.0983, -0.20, 0.0),
    ParamSpec("pendulum_com_z", 0.0, -0.05, 0.05),
    ParamSpec("pendulum_ixx", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_iyy", 6.16e-06, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_izz", 1.23e-05, 1e-07, 1e-04, log_scale=True),
    ParamSpec("pendulum_ixy", 6.10e-06, -1e-04, 1e-04),
    # ── Actuator / joint dynamics (robot.yaml) ───────────────────
    ParamSpec("actuator_gear", 0.064, 0.01, 0.2, log_scale=True),
    ParamSpec("actuator_filter_tau", 0.03, 0.005, 0.15),
    ParamSpec("motor_damping", 0.003, 1e-4, 0.05, log_scale=True),
    ParamSpec("pendulum_damping", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_armature", 0.0001, 1e-5, 0.01, log_scale=True),
    ParamSpec("motor_frictionloss", 0.03, 0.001, 0.1, log_scale=True),
]
|
||||
|
||||
|
||||
def params_to_dict(
    values: np.ndarray, specs: list[ParamSpec] | None = None
) -> dict[str, float]:
    """Pair a flat parameter vector with spec names, position by position.

    Falls back to ROTARY_CARTPOLE_PARAMS when *specs* is omitted.
    """
    chosen = ROTARY_CARTPOLE_PARAMS if specs is None else specs
    named: dict[str, float] = {}
    for idx, spec in enumerate(chosen):
        named[spec.name] = float(values[idx])
    return named
|
||||
|
||||
|
||||
def defaults_vector(specs: list[ParamSpec] | None = None) -> np.ndarray:
    """Stack each spec's default into one float64 vector (natural units)."""
    chosen = ROTARY_CARTPOLE_PARAMS if specs is None else specs
    return np.fromiter(
        (s.default for s in chosen), dtype=np.float64, count=len(chosen)
    )
|
||||
|
||||
|
||||
def bounds_arrays(
    specs: list[ParamSpec] | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Collect every spec's box bounds into (lower, upper) float64 arrays."""
    chosen = ROTARY_CARTPOLE_PARAMS if specs is None else specs
    lower = np.fromiter(
        (s.lower for s in chosen), dtype=np.float64, count=len(chosen)
    )
    upper = np.fromiter(
        (s.upper for s in chosen), dtype=np.float64, count=len(chosen)
    )
    return lower, upper
|
||||
|
||||
|
||||
# ── MuJoCo model building with parameter overrides ──────────────────
|
||||
|
||||
|
||||
def _build_model(
    robot_path: Path,
    params: dict[str, float],
) -> mujoco.MjModel:
    """Build a MuJoCo model from URDF + robot.yaml with parameter overrides.

    Follows the same two-step approach as ``MuJoCoRunner._load_model()``:
      1. Parse URDF, inject meshdir, load into MuJoCo
      2. Export MJCF, inject actuators + joint overrides + param overrides, reload

    Temp files are written inside *robot_path* (presumably so relative
    mesh paths resolve against the asset directory — TODO confirm) and
    removed in ``finally`` blocks even when loading fails.
    """
    robot_path = Path(robot_path).resolve()
    robot_yaml = yaml.safe_load((robot_path / "robot.yaml").read_text())
    urdf_path = robot_path / robot_yaml["urdf"]

    # ── Step 1: Load URDF ────────────────────────────────────────
    tree = ET.parse(urdf_path)
    root = tree.getroot()

    # Inject meshdir compiler directive.
    # The directory is taken from the first mesh filename that has a
    # non-trivial parent path.
    meshdir = None
    for mesh_el in root.iter("mesh"):
        fn = mesh_el.get("filename", "")
        parent = str(Path(fn).parent)
        if parent and parent != ".":
            meshdir = parent
            break
    if meshdir:
        mj_ext = ET.SubElement(root, "mujoco")
        ET.SubElement(
            mj_ext, "compiler", attrib={"meshdir": meshdir, "balanceinertia": "true"}
        )

    # Override URDF inertial parameters BEFORE MuJoCo loading.
    # Missing params leave the URDF's original values in place.
    for link in root.iter("link"):
        link_name = link.get("name", "")
        inertial = link.find("inertial")
        if inertial is None:
            continue

        if link_name == "arm":
            _set_mass(inertial, params.get("arm_mass"))
            _set_com(
                inertial,
                params.get("arm_com_x"),
                params.get("arm_com_y"),
                params.get("arm_com_z"),
            )

        elif link_name == "pendulum":
            _set_mass(inertial, params.get("pendulum_mass"))
            _set_com(
                inertial,
                params.get("pendulum_com_x"),
                params.get("pendulum_com_y"),
                params.get("pendulum_com_z"),
            )
            _set_inertia(
                inertial,
                ixx=params.get("pendulum_ixx"),
                iyy=params.get("pendulum_iyy"),
                izz=params.get("pendulum_izz"),
                ixy=params.get("pendulum_ixy"),
            )

    # Write temp URDF and load.
    tmp_urdf = robot_path / "_tmp_sysid_load.urdf"
    tree.write(str(tmp_urdf), xml_declaration=True, encoding="unicode")
    try:
        model_raw = mujoco.MjModel.from_xml_path(str(tmp_urdf))
    finally:
        tmp_urdf.unlink(missing_ok=True)

    # ── Step 2: Export MJCF, inject actuators + overrides ────────
    # mj_saveLastXML dumps the MJCF MuJoCo compiled from the URDF; the
    # URDF loader itself has no way to express actuators or joint params.
    tmp_mjcf = robot_path / "_tmp_sysid_inject.xml"
    try:
        mujoco.mj_saveLastXML(str(tmp_mjcf), model_raw)
        mjcf_root = ET.fromstring(tmp_mjcf.read_text())

        # Actuator. Only the first entry of robot.yaml's "actuators" is used.
        gear = params.get("actuator_gear", robot_yaml["actuators"][0].get("gear", 0.064))
        filter_tau = params.get(
            "actuator_filter_tau",
            robot_yaml["actuators"][0].get("filter_tau", 0.03),
        )
        act_cfg = robot_yaml["actuators"][0]
        ctrl_lo, ctrl_hi = act_cfg.get("ctrl_range", [-1.0, 1.0])

        act_elem = ET.SubElement(mjcf_root, "actuator")
        attribs: dict[str, str] = {
            "name": f"{act_cfg['joint']}_motor",
            "joint": act_cfg["joint"],
            "gear": str(gear),
            "ctrlrange": f"{ctrl_lo} {ctrl_hi}",
        }
        if filter_tau > 0:
            # General actuator with dyntype=filter: first-order low-pass on
            # the control signal with time constant filter_tau (models the
            # motor's lagged response to PWM commands).
            attribs["dyntype"] = "filter"
            attribs["dynprm"] = str(filter_tau)
            attribs["gaintype"] = "fixed"
            attribs["biastype"] = "none"
            ET.SubElement(act_elem, "general", attrib=attribs)
        else:
            ET.SubElement(act_elem, "motor", attrib=attribs)

        # Joint overrides.
        motor_damping = params.get("motor_damping", 0.003)
        pend_damping = params.get("pendulum_damping", 0.0001)
        motor_armature = params.get("motor_armature", 0.0001)
        motor_frictionloss = params.get("motor_frictionloss", 0.03)

        for body in mjcf_root.iter("body"):
            for jnt in body.findall("joint"):
                name = jnt.get("name")
                if name == "motor_joint":
                    jnt.set("damping", str(motor_damping))
                    jnt.set("armature", str(motor_armature))
                    jnt.set("frictionloss", str(motor_frictionloss))
                elif name == "pendulum_joint":
                    jnt.set("damping", str(pend_damping))

        # Disable self-collision.
        for geom in mjcf_root.iter("geom"):
            geom.set("contype", "0")
            geom.set("conaffinity", "0")

        modified_xml = ET.tostring(mjcf_root, encoding="unicode")
        tmp_mjcf.write_text(modified_xml)
        model = mujoco.MjModel.from_xml_path(str(tmp_mjcf))
    finally:
        tmp_mjcf.unlink(missing_ok=True)

    return model
|
||||
|
||||
|
||||
def _set_mass(inertial: ET.Element, mass: float | None) -> None:
    """Set the ``value`` attribute of the <mass> child of *inertial*.

    No-op when *mass* is None or the element has no <mass> child.
    """
    if mass is None:
        return
    mass_el = inertial.find("mass")
    if mass_el is not None:
        mass_el.set("value", str(mass))
|
||||
|
||||
|
||||
def _set_com(
    inertial: ET.Element,
    x: float | None,
    y: float | None,
    z: float | None,
) -> None:
    """Overwrite components of the inertial <origin xyz="..."> attribute.

    Components given as None keep their existing value; a missing
    <origin> child makes the call a no-op.
    """
    origin = inertial.find("origin")
    if origin is None:
        return
    xyz = origin.get("xyz", "0 0 0").split()
    if x is not None:
        xyz[0] = str(x)
    if y is not None:
        xyz[1] = str(y)
    if z is not None:
        xyz[2] = str(z)
    origin.set("xyz", " ".join(xyz))
|
||||
|
||||
|
||||
def _set_inertia(
    inertial: ET.Element,
    ixx: float | None = None,
    iyy: float | None = None,
    izz: float | None = None,
    ixy: float | None = None,
    iyz: float | None = None,
    ixz: float | None = None,
) -> None:
    """Overwrite the non-None components of the <inertia> element.

    No-op when *inertial* has no <inertia> child.
    """
    ine = inertial.find("inertia")
    if ine is None:
        return
    for attr, val in [
        ("ixx", ixx), ("iyy", iyy), ("izz", izz),
        ("ixy", ixy), ("iyz", iyz), ("ixz", ixz),
    ]:
        if val is not None:
            ine.set(attr, str(val))
|
||||
|
||||
|
||||
# ── Simulation rollout ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def rollout(
    robot_path: str | Path,
    params: dict[str, float],
    actions: np.ndarray,
    timesteps: np.ndarray,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray]:
    """Replay recorded actions in MuJoCo with overridden parameters.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    actions : (N,) normalised actions [-1, 1] from the recording
    timesteps : (N,) wall-clock times (seconds) from the recording.
        Currently unused: the replay advances by a fixed control period of
        ``sim_dt * substeps`` per sample. Kept for interface compatibility
        (and so a future version can resample against the real clock).
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step

    Returns
    -------
    dict with keys: motor_angle, motor_vel, pendulum_angle, pendulum_vel
    Each is an (N,) numpy array of simulated values.
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    # Start from pendulum hanging down (qpos=0 is down per URDF convention).
    mujoco.mj_resetData(model, data)

    n = len(actions)

    # Pre-allocate output.
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info for software limit switch.
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    for i in range(n):
        data.ctrl[0] = actions[i]

        for _ in range(substeps):
            # Software limit switch (mirrors MuJoCoRunner): zero the control
            # when it would push further into a joint limit.
            if jnt_limited and nu > 0:
                pos = data.qpos[jnt_id]
                if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                    data.ctrl[0] = 0.0
                elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                    data.ctrl[0] = 0.0
            mujoco.mj_step(model, data)

        # Sample joint state once per control step.
        sim_motor_angle[i] = data.qpos[0]
        sim_motor_vel[i] = data.qvel[0]
        sim_pend_angle[i] = data.qpos[1]
        sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
    }
|
||||
|
||||
|
||||
def windowed_rollout(
    robot_path: str | Path,
    params: dict[str, float],
    recording: dict[str, np.ndarray],
    window_duration: float = 0.5,
    sim_dt: float = 0.002,
    substeps: int = 10,
) -> dict[str, np.ndarray | float]:
    """Multiple-shooting rollout — split recording into short windows.

    For each window:
      1. Initialize MuJoCo state from the real qpos/qvel at the window start.
      2. Replay the recorded actions within the window.
      3. Record the simulated output.

    This prevents error accumulation across the full trajectory, giving
    a much smoother cost landscape for the optimizer.

    Parameters
    ----------
    robot_path : asset directory
    params : named parameter overrides
    recording : dict with keys time, action, motor_angle, motor_vel,
        pendulum_angle, pendulum_vel (all 1D arrays of length N)
    window_duration : length of each shooting window in seconds
    sim_dt : MuJoCo physics timestep
    substeps : physics substeps per control step
        (one recorded action is held for `substeps` physics steps)

    Returns
    -------
    dict with:
        motor_angle, motor_vel, pendulum_angle, pendulum_vel — (N,) arrays
        (stitched from per-window simulations)
        n_windows — number of windows used
    """
    robot_path = Path(robot_path).resolve()
    model = _build_model(robot_path, params)
    model.opt.timestep = sim_dt
    data = mujoco.MjData(model)

    # Unpack the real-hardware trace; these arrays seed each window's state.
    times = recording["time"]
    actions = recording["action"]
    real_motor = recording["motor_angle"]
    real_motor_vel = recording["motor_vel"]
    real_pend = recording["pendulum_angle"]
    real_pend_vel = recording["pendulum_vel"]
    n = len(actions)

    # Pre-allocate output (stitched from all windows).
    sim_motor_angle = np.zeros(n, dtype=np.float64)
    sim_motor_vel = np.zeros(n, dtype=np.float64)
    sim_pend_angle = np.zeros(n, dtype=np.float64)
    sim_pend_vel = np.zeros(n, dtype=np.float64)

    # Extract actuator limit info (for the software limit switch below).
    nu = model.nu
    if nu > 0:
        jnt_id = model.actuator_trnid[0, 0]
        jnt_limited = bool(model.jnt_limited[jnt_id])
        jnt_lo = model.jnt_range[jnt_id, 0]
        jnt_hi = model.jnt_range[jnt_id, 1]
        # Sign of the gear so the cutoff works regardless of motor wiring.
        gear_sign = float(np.sign(model.actuator_gear[0, 0]))
    else:
        jnt_limited = False
        jnt_lo = jnt_hi = gear_sign = 0.0

    # Compute window boundaries from recording timestamps.
    # NOTE(review): searchsorted can return duplicate indices if
    # window_duration is shorter than the sample period; that yields
    # empty windows (w_start == w_end), which is harmless.
    t0 = times[0]
    t_end = times[-1]
    window_starts: list[int] = []  # indices into the recording
    current_t = t0
    while current_t < t_end:
        # Find the index closest to current_t.
        idx = int(np.searchsorted(times, current_t))
        idx = min(idx, n - 1)
        window_starts.append(idx)
        current_t += window_duration

    n_windows = len(window_starts)

    for w, w_start in enumerate(window_starts):
        # Window end: next window start, or end of recording.
        w_end = window_starts[w + 1] if w + 1 < n_windows else n

        # Initialize MuJoCo state from real data at window start.
        # Index 0 = motor joint, index 1 = pendulum joint (fixed by the
        # rotary-cartpole model layout).
        mujoco.mj_resetData(model, data)
        data.qpos[0] = real_motor[w_start]
        data.qpos[1] = real_pend[w_start]
        data.qvel[0] = real_motor_vel[w_start]
        data.qvel[1] = real_pend_vel[w_start]
        data.ctrl[:] = 0.0
        # Forward kinematics to make state consistent.
        mujoco.mj_forward(model, data)

        for i in range(w_start, w_end):
            data.ctrl[0] = actions[i]

            for _ in range(substeps):
                # Software limit switch (mirrors MuJoCoRunner): cut drive
                # torque when pushing further into a joint limit.
                if jnt_limited and nu > 0:
                    # NOTE(review): indexes qpos by joint id, not
                    # jnt_qposadr — equivalent only for hinge-joint-only
                    # models like this one; confirm if the model changes.
                    pos = data.qpos[jnt_id]
                    if pos >= jnt_hi and gear_sign * data.ctrl[0] > 0:
                        data.ctrl[0] = 0.0
                    elif pos <= jnt_lo and gear_sign * data.ctrl[0] < 0:
                        data.ctrl[0] = 0.0
                mujoco.mj_step(model, data)

            # Record state after the full control step.
            sim_motor_angle[i] = data.qpos[0]
            sim_motor_vel[i] = data.qvel[0]
            sim_pend_angle[i] = data.qpos[1]
            sim_pend_vel[i] = data.qvel[1]

    return {
        "motor_angle": sim_motor_angle,
        "motor_vel": sim_motor_vel,
        "pendulum_angle": sim_pend_angle,
        "pendulum_vel": sim_pend_vel,
        "n_windows": n_windows,
    }
|
||||
287
src/sysid/visualize.py
Normal file
287
src/sysid/visualize.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Visualise system identification results — real vs simulated trajectories.
|
||||
|
||||
Loads a recording and runs simulation with both the original and tuned
|
||||
parameters, then plots a 4-panel comparison (motor angle, motor vel,
|
||||
pendulum angle, pendulum vel) over time.
|
||||
|
||||
Usage:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording assets/rotary_cartpole/recordings/capture_20260311_120000.npz
|
||||
|
||||
# Also compare with tuned parameters:
|
||||
python -m src.sysid.visualize \
|
||||
--robot-path assets/rotary_cartpole \
|
||||
--recording <file>.npz \
|
||||
--result assets/rotary_cartpole/sysid_result.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import structlog
|
||||
|
||||
log = structlog.get_logger()
|
||||
|
||||
|
||||
def visualize(
    robot_path: str | Path,
    recording_path: str | Path,
    result_path: str | Path | None = None,
    sim_dt: float = 0.002,
    substeps: int = 10,
    window_duration: float = 0.5,
    save_path: str | Path | None = None,
    show: bool = True,
) -> None:
    """Generate comparison plot of real vs simulated trajectories.

    Runs a rollout with the default parameters, and — if a sysid result
    file is given (or auto-detected at ``<robot_path>/sysid_result.json``)
    — a second rollout with the tuned parameters, then plots a 5-panel
    figure (four state channels plus the action trace).

    Parameters
    ----------
    robot_path : robot asset directory
    recording_path : .npz file from capture
    result_path : sysid_result.json with best_params (optional)
    sim_dt / substeps : physics settings for rollout
    window_duration : shooting window length (s); 0 = open-loop
    save_path : if provided, save figure to this path (PNG, PDF, …)
    show : if True, display interactive matplotlib window
    """
    # Imported lazily so headless/sysid-only use doesn't need matplotlib.
    import matplotlib.pyplot as plt

    from src.sysid.rollout import (
        ROTARY_CARTPOLE_PARAMS,
        defaults_vector,
        params_to_dict,
        rollout,
        windowed_rollout,
    )

    robot_path = Path(robot_path).resolve()
    recording = dict(np.load(recording_path))

    t = recording["time"]
    actions = recording["action"]

    def _simulate(sim_params: dict[str, float]):
        """Run one rollout with *sim_params* — windowed (multiple-shooting)
        when window_duration > 0, open-loop otherwise."""
        if window_duration > 0:
            return windowed_rollout(
                robot_path=robot_path,
                params=sim_params,
                recording=recording,
                window_duration=window_duration,
                sim_dt=sim_dt,
                substeps=substeps,
            )
        return rollout(
            robot_path=robot_path,
            params=sim_params,
            actions=actions,
            timesteps=t,
            sim_dt=sim_dt,
            substeps=substeps,
        )

    # ── Simulate with default parameters ─────────────────────────
    default_params = params_to_dict(
        defaults_vector(ROTARY_CARTPOLE_PARAMS), ROTARY_CARTPOLE_PARAMS
    )
    log.info("simulating_default_params", windowed=window_duration > 0)
    sim_default = _simulate(default_params)

    # ── Simulate with tuned parameters (if available) ────────────
    sim_tuned = None
    tuned_cost = None
    if result_path is not None:
        result_path = Path(result_path)
        if result_path.exists():
            result = json.loads(result_path.read_text())
            tuned_params = result.get("best_params", {})
            tuned_cost = result.get("best_cost")
            log.info("simulating_tuned_params", cost=tuned_cost)
            sim_tuned = _simulate(tuned_params)
        else:
            # Explicit path given but missing — warn, fall back to default-only.
            log.warning("result_file_not_found", path=str(result_path))
    else:
        # Auto-detect sysid_result.json in robot_path.
        auto_result = robot_path / "sysid_result.json"
        if auto_result.exists():
            result = json.loads(auto_result.read_text())
            tuned_params = result.get("best_params", {})
            tuned_cost = result.get("best_cost")
            log.info("auto_detected_tuned_params", cost=tuned_cost)
            sim_tuned = _simulate(tuned_params)

    # ── Plot ─────────────────────────────────────────────────────
    fig, axes = plt.subplots(5, 1, figsize=(14, 12), sharex=True)

    # (recording key, y-axis label, is_angle flag). The flag is currently
    # unused by the plotting loop; kept for future angle-wrapping logic.
    channels = [
        ("motor_angle", "Motor Angle (rad)", True),
        ("motor_vel", "Motor Velocity (rad/s)", False),
        ("pendulum_angle", "Pendulum Angle (rad)", True),
        ("pendulum_vel", "Pendulum Velocity (rad/s)", False),
    ]

    for ax, (key, ylabel, _is_angle) in zip(axes[:4], channels):
        real = recording[key]

        ax.plot(t, real, "k-", linewidth=1.2, alpha=0.8, label="Real")
        ax.plot(
            t,
            sim_default[key],
            "--",
            color="#d62728",
            linewidth=1.0,
            alpha=0.7,
            label="Sim (original)",
        )
        if sim_tuned is not None:
            ax.plot(
                t,
                sim_tuned[key],
                "--",
                color="#2ca02c",
                linewidth=1.0,
                alpha=0.7,
                label="Sim (tuned)",
            )

        ax.set_ylabel(ylabel)
        ax.legend(loc="upper right", fontsize=8)
        ax.grid(True, alpha=0.3)

    # Action plot (bottom panel).
    axes[4].plot(t, actions, "b-", linewidth=0.8, alpha=0.6)
    axes[4].set_ylabel("Action (norm)")
    axes[4].set_xlabel("Time (s)")
    axes[4].grid(True, alpha=0.3)
    axes[4].set_ylim(-1.1, 1.1)

    # Title with cost info.
    title = "System Identification — Real vs Simulated Trajectories"
    if tuned_cost is not None:
        # Compute original cost for comparison.
        from src.sysid.optimize import cost_function

        orig_cost = cost_function(
            defaults_vector(ROTARY_CARTPOLE_PARAMS),
            recording,
            robot_path,
            ROTARY_CARTPOLE_PARAMS,
            sim_dt=sim_dt,
            substeps=substeps,
            window_duration=window_duration,
        )
        title += f"\nOriginal cost: {orig_cost:.4f} → Tuned cost: {tuned_cost:.4f}"
        # Guard against division by zero for a degenerate zero-cost baseline.
        improvement = (1.0 - tuned_cost / orig_cost) * 100 if orig_cost > 0 else 0
        title += f" ({improvement:+.1f}%)"

    fig.suptitle(title, fontsize=12)
    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        fig.savefig(str(save_path), dpi=150, bbox_inches="tight")
        log.info("figure_saved", path=str(save_path))

    if show:
        plt.show()
    else:
        plt.close(fig)
|
||||
|
||||
|
||||
# ── CLI entry point ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse command-line flags and call :func:`visualize`."""
    ap = argparse.ArgumentParser(
        description="Visualise system identification results."
    )
    ap.add_argument("--robot-path", type=str, default="assets/rotary_cartpole")
    ap.add_argument(
        "--recording",
        type=str,
        required=True,
        help="Path to .npz recording file",
    )
    ap.add_argument(
        "--result",
        type=str,
        default=None,
        help="Path to sysid_result.json (auto-detected if omitted)",
    )
    ap.add_argument("--sim-dt", type=float, default=0.002)
    ap.add_argument("--substeps", type=int, default=10)
    ap.add_argument(
        "--window-duration",
        type=float,
        default=0.5,
        help="Shooting window length in seconds (0 = open-loop)",
    )
    ap.add_argument(
        "--save",
        type=str,
        default=None,
        help="Save figure to this path (PNG, PDF, …)",
    )
    ap.add_argument(
        "--no-show",
        action="store_true",
        help="Don't show interactive window (useful for CI)",
    )
    ns = ap.parse_args()

    visualize(
        robot_path=ns.robot_path,
        recording_path=ns.recording,
        result_path=ns.result,
        sim_dt=ns.sim_dt,
        substeps=ns.substeps,
        window_duration=ns.window_duration,
        save_path=ns.save,
        show=not ns.no_show,
    )
|
||||
|
||||
|
||||
# Script entry: `python -m src.sysid.visualize ...`.
if __name__ == "__main__":
    main()
|
||||
@@ -35,6 +35,11 @@ class TrainerConfig:
|
||||
|
||||
hidden_sizes: tuple[int, ...] = (64, 64)
|
||||
|
||||
# Policy
|
||||
initial_log_std: float = 0.5 # initial exploration noise
|
||||
min_log_std: float = -2.0 # minimum exploration noise
|
||||
max_log_std: float = 2.0 # maximum exploration noise (2.0 ≈ σ=7.4)
|
||||
|
||||
# Training
|
||||
total_timesteps: int = 1_000_000
|
||||
log_interval: int = 10
|
||||
@@ -110,6 +115,7 @@ class VideoRecordingTrainer(SequentialTrainer):
|
||||
return self._tcfg.record_video_fps
|
||||
dt = getattr(self.env.config, "dt", 0.02)
|
||||
substeps = getattr(self.env.config, "substeps", 1)
|
||||
# SerialRunner has dt but no substeps — dt *is* the control period.
|
||||
return max(1, int(round(1.0 / (dt * substeps))))
|
||||
|
||||
def _record_video(self, timestep: int) -> None:
|
||||
@@ -181,8 +187,9 @@ class Trainer:
|
||||
action_space=act_space,
|
||||
device=device,
|
||||
hidden_sizes=self.config.hidden_sizes,
|
||||
initial_log_std=0.5,
|
||||
min_log_std=-2.0,
|
||||
initial_log_std=self.config.initial_log_std,
|
||||
min_log_std=self.config.min_log_std,
|
||||
max_log_std=self.config.max_log_std,
|
||||
)
|
||||
|
||||
models = {"policy": self.model, "value": self.model}
|
||||
|
||||
Reference in New Issue
Block a user