import abc
import dataclasses
from typing import Any, Generic, TypeVar

import torch

from src.core.env import BaseEnv


@dataclasses.dataclass
class BaseRunnerConfig:
    """Base configuration shared by all runners."""

    # Number of parallel simulated environments.
    num_envs: int = 1
    # Torch device string used for the runner's bookkeeping tensors.
    device: str = "cpu"


# Bounded: BaseRunner.__init__ reads config.num_envs and config.device,
# so any concrete config type must derive from BaseRunnerConfig.
T = TypeVar("T", bound=BaseRunnerConfig)


class BaseRunner(abc.ABC, Generic[T]):
    """Abstract runner bridging a :class:`BaseEnv` and a simulator backend.

    Subclasses implement the ``_sim_*`` hooks (initialize / step / reset /
    close); this base class provides the Gym-style ``reset``/``step`` loop,
    including per-environment auto-reset of finished episodes.

    Args:
        env: Environment providing observation/reward/termination logic.
        config: Runner configuration; must expose ``num_envs`` and ``device``.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        self.env = env
        self.config = config
        # Backend-specific setup must run before spaces/tensors are touched.
        self._sim_initialize(config)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)
        # Per-env episode step counters, consumed by compute_truncations.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int: ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device: ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None: ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abc.abstractmethod
    def _sim_close(self) -> None: ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset all environments and return (observations, empty info dict)."""
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Advance the simulation one step, auto-resetting finished envs.

        Returns:
            ``(obs, rewards, terminated, truncated, info)``. Rewards and the
            two done flags are unsqueezed to shape ``(num_envs, 1)`` as skrl
            expects. When any env finishes this step, its terminal observation
            and index are stored under ``info["final_observations"]`` /
            ``info["final_env_ids"]`` before its slot in ``obs`` is replaced
            by the post-reset observation.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1
        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)
        info: dict[str, Any] = {}
        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)
        if done_ids.numel() > 0:
            # Preserve terminal observations before overwriting with reset obs.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()
            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0
            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)
        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return (
            obs,
            rewards.unsqueeze(-1),
            terminated.unsqueeze(-1),
            truncated.unsqueeze(-1),
            info,
        )

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render one environment; backends that support it must override."""
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Release simulator resources via the backend hook."""
        self._sim_close()