jax[cuda12] on Linux for GPU; EGL headless rendering; non-fatal video recording
This commit is contained in:
@@ -4,7 +4,8 @@ hydra-core
|
|||||||
omegaconf
|
omegaconf
|
||||||
mujoco==3.5.0
|
mujoco==3.5.0
|
||||||
mujoco-mjx==3.5.0
|
mujoco-mjx==3.5.0
|
||||||
jax==0.9.1
|
jax[cuda12]==0.9.1 ; sys_platform == "linux"
|
||||||
|
jax==0.9.1 ; sys_platform != "linux"
|
||||||
skrl[torch]==1.4.3
|
skrl[torch]==1.4.3
|
||||||
clearml
|
clearml
|
||||||
imageio
|
imageio
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ _PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
|
|||||||
if _PROJECT_ROOT not in sys.path:
|
if _PROJECT_ROOT not in sys.path:
|
||||||
sys.path.insert(0, _PROJECT_ROOT)
|
sys.path.insert(0, _PROJECT_ROOT)
|
||||||
|
|
||||||
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import).
|
# Headless rendering on Linux servers (must be set before mujoco import).
|
||||||
# Always default on Linux — Docker containers may have DISPLAY set without a real X server.
|
# EGL renders on the GPU directly (right for NVIDIA nodes) and avoids the
|
||||||
|
# brittle OSMesa/PyOpenGL stack. Forced (not setdefault) so a stale
|
||||||
|
# `-e MUJOCO_GL=osmesa` baked into a remote task can't override it.
|
||||||
if sys.platform == "linux":
|
if sys.platform == "linux":
|
||||||
os.environ.setdefault("MUJOCO_GL", "osmesa")
|
os.environ["MUJOCO_GL"] = "egl"
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import hydra.utils as hydra_utils
|
import hydra.utils as hydra_utils
|
||||||
@@ -74,11 +76,11 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
|
|||||||
"git.victormylle.be/victormylle/simple-rl-framework:latest",
|
"git.victormylle.be/victormylle/simple-rl-framework:latest",
|
||||||
docker_setup_bash_script=(
|
docker_setup_bash_script=(
|
||||||
"apt-get update && apt-get install -y --no-install-recommends "
|
"apt-get update && apt-get install -y --no-install-recommends "
|
||||||
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* "
|
"libegl1 libgl1 libglfw3 libosmesa6 && rm -rf /var/lib/apt/lists/* "
|
||||||
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0 PyOpenGL PyOpenGL-accelerate"
|
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0"
|
||||||
),
|
),
|
||||||
docker_arguments=[
|
docker_arguments=[
|
||||||
"-e", "MUJOCO_GL=osmesa",
|
"-e", "MUJOCO_GL=egl",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -131,35 +131,40 @@ class VideoRecordingTrainer(SequentialTrainer):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
return
|
return
|
||||||
|
|
||||||
fps = self._get_fps()
|
# Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
|
||||||
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
# failure crash training — log it and skip the video.
|
||||||
frames: list[np.ndarray] = []
|
try:
|
||||||
|
fps = self._get_fps()
|
||||||
|
max_steps = getattr(self.env.env.config, "max_steps", 500)
|
||||||
|
frames: list[np.ndarray] = []
|
||||||
|
|
||||||
obs, _ = self.env.reset()
|
obs, _ = self.env.reset()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
|
||||||
obs, _, terminated, truncated, _ = self.env.step(action)
|
obs, _, terminated, truncated, _ = self.env.step(action)
|
||||||
|
|
||||||
frame = self.env.render()
|
frame = self.env.render()
|
||||||
if frame is not None:
|
if frame is not None:
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
|
|
||||||
if (terminated | truncated).any().item():
|
if (terminated | truncated).any().item():
|
||||||
break
|
break
|
||||||
|
|
||||||
if frames:
|
if frames:
|
||||||
path = str(self._video_dir / f"step_{timestep}.mp4")
|
path = str(self._video_dir / f"step_{timestep}.mp4")
|
||||||
iio.imwrite(path, frames, fps=fps)
|
iio.imwrite(path, frames, fps=fps)
|
||||||
|
|
||||||
logger = Logger.current_logger()
|
logger = Logger.current_logger()
|
||||||
if logger:
|
if logger:
|
||||||
logger.report_media(
|
logger.report_media(
|
||||||
"Training Video", f"step_{timestep}",
|
"Training Video", f"step_{timestep}",
|
||||||
local_path=path, iteration=timestep,
|
local_path=path, iteration=timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("video_recording_failed", timestep=timestep, error=str(exc))
|
||||||
|
|
||||||
|
|
||||||
# ── Main trainer ─────────────────────────────────────────────────────
|
# ── Main trainer ─────────────────────────────────────────────────────
|
||||||
|
|||||||
Reference in New Issue
Block a user