"""Abstract runner base classes bridging a simulator backend and a BaseEnv."""
import dataclasses
|
|
import abc
|
|
from typing import Any, Generic, TypeVar
|
|
from src.core.env import BaseEnv
|
|
import torch
|
|
|
|
|
|
# Config type handled by a concrete runner (typically a BaseRunnerConfig subclass —
# TODO confirm; nothing here constrains T beyond Generic[T] usage below).
T = TypeVar("T")
|
|
|
|
@dataclasses.dataclass
class BaseRunnerConfig:
    """Minimal configuration shared by all runners."""

    # Number of parallel simulated environments.
    num_envs: int = 1
    # Torch device string (e.g. "cpu" or "cuda:0") for runner-side tensors.
    device: str = "cpu"
|
|
|
|
class BaseRunner(abc.ABC, Generic[T]):
    """Abstract base for runners that drive a vectorized ``BaseEnv``.

    Subclasses implement the ``_sim_*`` hooks against a concrete simulator
    backend; this class provides the gym-style ``reset``/``step`` loop
    (with per-env auto-reset) consumed by skrl.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        """Store env/config, initialize the simulator, and allocate counters.

        Args:
            env: Environment providing observation/reward/termination logic.
            config: Backend-specific configuration; must expose ``num_envs``
                and ``device`` (see ``BaseRunnerConfig``).
        """
        self.env = env
        self.config = config

        # Backend-specific one-time setup (scene construction, buffers, ...).
        self._sim_initialize(config)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)

        # Per-env step counter used for time-limit truncation.
        # NOTE(review): allocated from config.num_envs/config.device, while
        # reset()/step() use the abstract num_envs/device properties — these
        # are assumed to agree; confirm in subclasses.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environments managed by the simulator."""
        ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Device on which simulation tensors live."""
        ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """One-time simulator setup; called exactly once from ``__init__``."""
        ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the simulation one step; return ``(qpos, qvel)``."""
        ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the given env indices; return ``(qpos, qvel)``."""
        ...

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release simulator resources."""
        ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset every environment and return initial observations.

        Returns:
            ``(obs, info)`` where ``info`` is an empty dict.
        """
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all environments, auto-resetting any that finished.

        FIX: the return annotation previously declared a 4-tuple although the
        method returns five values; it now matches the actual return.

        Args:
            actions: Action batch forwarded to the simulator.

        Returns:
            ``(obs, rewards, terminated, truncated, info)``. Rewards and the
            done flags are reshaped to ``(num_envs, 1)`` as skrl expects.
            When any env finished this step, ``info`` carries
            ``final_observations`` (pre-reset observations of those envs) and
            ``final_env_ids``.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)

        info: dict[str, Any] = {}

        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)

        if done_ids.numel() > 0:
            # Preserve terminal observations before overwriting them with the
            # post-reset ones, so the learner can bootstrap correctly.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()

            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0

            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)

        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render one environment; concrete runners may override."""
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Shut down the simulator backend."""
        self._sim_close()