✨ initial commit
This commit is contained in:
97
src/core/runner.py
Normal file
97
src/core/runner.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import dataclasses
|
||||
import abc
|
||||
from typing import Any, Generic, TypeVar
|
||||
from src.core.env import BaseEnv
|
||||
import torch
|
||||
|
||||
|
||||
# Config type variable: parametrizes BaseRunner (Generic[T]) and the
# `config: T` argument each concrete runner receives.
T = TypeVar("T")
|
||||
|
||||
@dataclasses.dataclass
class BaseRunnerConfig:
    """Base configuration consumed by BaseRunner.__init__ (num_envs, device)."""

    # Number of parallel environments; sizes the per-env step counter buffer.
    num_envs: int = 1
    # Torch device spec for allocated buffers (passed to torch.zeros(device=...)).
    device: str = "cpu"
|
||||
|
||||
class BaseRunner(abc.ABC, Generic[T]):
    """Abstract base for vectorized simulation runners.

    Pairs a simulator backend (the ``_sim_*`` hooks implemented by
    subclasses) with a ``BaseEnv`` that supplies observations, rewards,
    terminations, and truncations, exposing a Gym-style ``reset``/``step``
    interface with in-place auto-reset of finished environments.
    """

    def __init__(self, env: BaseEnv, config: T) -> None:
        """Store env/config, initialize the simulator, allocate counters.

        Args:
            env: Task logic provider (build_state, compute_observations,
                compute_rewards, compute_terminations, compute_truncations).
            config: Runner configuration; must expose ``num_envs`` and
                ``device`` attributes (cf. ``BaseRunnerConfig``).
        """
        self.env = env
        self.config = config

        # Subclass hook: the simulator must exist before the spaces and
        # buffers below are touched.
        self._sim_initialize(config)

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.num_agents: int = 1  # single-agent RL (required by skrl)

        # Per-environment episode step counter, used for truncation.
        # NOTE(review): reads config.num_envs/config.device directly rather
        # than the abstract num_envs/device properties — assumes they agree;
        # confirm against concrete subclasses.
        self.step_counts = torch.zeros(
            self.config.num_envs, dtype=torch.long, device=self.config.device
        )

    @property
    @abc.abstractmethod
    def num_envs(self) -> int:
        """Number of parallel environments managed by the simulator."""
        ...

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """Torch device on which simulation tensors live."""
        ...

    @abc.abstractmethod
    def _sim_initialize(self, config: T) -> None:
        """Create and configure the underlying simulator from ``config``."""
        ...

    @abc.abstractmethod
    def _sim_step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the simulation one step; return ``(qpos, qvel)``."""
        ...

    @abc.abstractmethod
    def _sim_reset(self, env_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reset the envs in ``env_ids``; return ``(qpos, qvel)``.

        NOTE(review): whether the returned tensors cover all envs or only
        ``env_ids`` is not fixed by this file — confirm against subclasses
        and ``BaseEnv.build_state``.
        """
        ...

    @abc.abstractmethod
    def _sim_close(self) -> None:
        """Release simulator resources."""
        ...

    def reset(self) -> tuple[torch.Tensor, dict[str, Any]]:
        """Reset every environment.

        Returns:
            ``(obs, info)`` with an empty info dict.
        """
        all_ids = torch.arange(self.num_envs, device=self.device)
        qpos, qvel = self._sim_reset(all_ids)
        self.step_counts.zero_()

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        return obs, {}

    def step(
        self, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Step all environments with ``actions``.

        Environments that terminate or truncate are reset in place; their
        pre-reset observations are surfaced via
        ``info["final_observations"]`` / ``info["final_env_ids"]``.

        Args:
            actions: Batched actions forwarded to ``_sim_step``.

        Returns:
            ``(obs, rewards, terminated, truncated, info)``; rewards,
            terminated, and truncated carry a trailing singleton dimension.

        Fix: the return annotation previously declared a 4-tuple while five
        values are returned; corrected to the actual 5-tuple.
        """
        qpos, qvel = self._sim_step(actions)
        self.step_counts += 1

        state = self.env.build_state(qpos, qvel)
        obs = self.env.compute_observations(state)
        rewards = self.env.compute_rewards(state, actions)
        terminated = self.env.compute_terminations(state)
        truncated = self.env.compute_truncations(self.step_counts)

        info: dict[str, Any] = {}

        done = terminated | truncated
        done_ids = done.nonzero(as_tuple=False).squeeze(-1)

        if done_ids.numel() > 0:
            # Capture the last observation of finished episodes before the
            # in-place auto-reset overwrites it.
            info["final_observations"] = obs[done_ids].clone()
            info["final_env_ids"] = done_ids.clone()

            reset_qpos, reset_qvel = self._sim_reset(done_ids)
            self.step_counts[done_ids] = 0

            reset_state = self.env.build_state(reset_qpos, reset_qvel)
            obs[done_ids] = self.env.compute_observations(reset_state)

        # skrl expects (num_envs, 1) for rewards/terminated/truncated
        return obs, rewards.unsqueeze(-1), terminated.unsqueeze(-1), truncated.unsqueeze(-1), info

    def render(self, env_idx: int = 0, mode: str = "human") -> torch.Tensor | None:
        """Render a single environment; not supported by this base class.

        Raises:
            NotImplementedError: always, unless a subclass overrides.
        """
        raise NotImplementedError("Render method not implemented for this runner.")

    def close(self) -> None:
        """Shut down the simulator backend."""
        self._sim_close()
|
||||
Reference in New Issue
Block a user