Files
RL-Sim-Framework/src/core/runner.py
2026-03-06 22:19:44 +01:00

97 lines
3.0 KiB
Python

import dataclasses
import abc
from typing import Any, Generic, TypeVar
from src.core.env import BaseEnv
import torch
# Config type for BaseRunner. Bounded to BaseRunnerConfig (forward ref — the
# class is defined below) because BaseRunner.__init__ reads config.num_envs
# and config.device; any concrete config must provide at least those fields.
T = TypeVar("T", bound="BaseRunnerConfig")
@dataclasses.dataclass
class BaseRunnerConfig:
    """Common configuration shared by all runners.

    Concrete runner configs are expected to extend this with
    backend-specific fields.
    """

    # Number of parallel simulation environments.
    num_envs: int = 1
    # Torch device string (e.g. "cpu" or "cuda:0") used for runner tensors.
    device: str = "cpu"
class BaseRunner(abc.ABC, Generic[T]):
    """Abstract bridge between a simulator backend and a :class:`BaseEnv`.

    Subclasses implement the four ``_sim_*`` hooks for a concrete physics
    backend; this base class drives the generic reset/step loop, including
    per-env step counting and automatic reset of done environments, and
    exposes a single-agent API shaped for skrl (rewards/terminated/truncated
    returned as ``(num_envs, 1)`` tensors).
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        """Store the env/config, initialize the backend, and allocate bookkeeping.

        Note: ``_sim_initialize`` runs before the ``num_envs``/``device``
        properties are used, so the step counter is sized from the raw
        config fields rather than the (possibly sim-dependent) properties.
        """
        self.env = env
        self.config = config
        self._sim_initialize(config)
        # Mirror the env's spaces so trainers can query them off the runner.
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)
        # Per-env step counter, used for time-limit truncation.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environments managed by the backend."""
        ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Device on which simulation state tensors live."""
        ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """One-time backend setup; called once from ``__init__``."""
        ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the simulation by one step; return ``(qpos, qvel)``."""
        ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given env indices; return ``(qpos, qvel)`` for ALL envs.

        NOTE(review): the auto-reset path in ``step`` indexes the rebuilt
        observations with ``done_ids``, which implies the returned tensors
        cover every env (not just ``env_ids``) — confirm against concrete
        implementations.
        """
        ...

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release backend resources; called from ``close``."""
        ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset every environment and return ``(observations, info)``."""
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all environments and auto-reset any that finished.

        Returns ``(obs, rewards, terminated, truncated, info)``. The original
        annotation declared a 4-tuple while the method returns 5 values; the
        annotation is corrected here — behavior is unchanged.

        For done envs, the pre-reset observation is exposed via
        ``info["final_observations"]`` (with ``info["final_env_ids"]``) and the
        returned ``obs`` already contains the post-reset observation, matching
        the usual vectorized-env autoreset convention.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)
        info: dict[str, Any] = {}
        # Assumes terminated/truncated are bool tensors — TODO confirm in BaseEnv.
        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)
        if done_ids.numel() > 0:
            # Preserve the terminal observation before overwriting with the
            # post-reset one, so off-policy bootstrapping can use it.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()
            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0
            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)
        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Optional per-env rendering hook; backends may override."""
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Shut down the simulation backend."""
        self._sim_close()