jax[cuda12] on linux for GPU; EGL headless render; non-fatal video

This commit is contained in:
2026-06-10 19:25:33 +02:00
parent a6fbde798a
commit 4210b6cb53
3 changed files with 38 additions and 30 deletions

View File

@@ -4,7 +4,8 @@ hydra-core
omegaconf omegaconf
mujoco==3.5.0 mujoco==3.5.0
mujoco-mjx==3.5.0 mujoco-mjx==3.5.0
jax==0.9.1 jax[cuda12]==0.9.1 ; sys_platform == "linux"
jax==0.9.1 ; sys_platform != "linux"
skrl[torch]==1.4.3 skrl[torch]==1.4.3
clearml clearml
imageio imageio

View File

@@ -8,10 +8,12 @@ _PROJECT_ROOT = str(pathlib.Path(__file__).resolve().parent.parent)
if _PROJECT_ROOT not in sys.path: if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT) sys.path.insert(0, _PROJECT_ROOT)
# Headless rendering: use OSMesa on Linux servers (must be set before mujoco import). # Headless rendering on Linux servers (must be set before mujoco import).
# Always default on Linux — Docker containers may have DISPLAY set without a real X server. # EGL renders on the GPU directly (right for NVIDIA nodes) and avoids the
# brittle OSMesa/PyOpenGL stack. Forced (not setdefault) so a stale
# `-e MUJOCO_GL=osmesa` baked into a remote task can't override it.
if sys.platform == "linux": if sys.platform == "linux":
os.environ.setdefault("MUJOCO_GL", "osmesa") os.environ["MUJOCO_GL"] = "egl"
import hydra import hydra
import hydra.utils as hydra_utils import hydra.utils as hydra_utils
@@ -74,11 +76,11 @@ def _init_clearml(choices: dict[str, str], remote: bool = False) -> Task:
"git.victormylle.be/victormylle/simple-rl-framework:latest", "git.victormylle.be/victormylle/simple-rl-framework:latest",
docker_setup_bash_script=( docker_setup_bash_script=(
"apt-get update && apt-get install -y --no-install-recommends " "apt-get update && apt-get install -y --no-install-recommends "
"libosmesa6-dev libgl1-mesa-glx libglfw3 && rm -rf /var/lib/apt/lists/* " "libegl1 libgl1 libglfw3 libosmesa6 && rm -rf /var/lib/apt/lists/* "
"&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0 PyOpenGL PyOpenGL-accelerate" "&& pip install 'jax[cuda12]==0.9.1' mujoco-mjx==3.5.0"
), ),
docker_arguments=[ docker_arguments=[
"-e", "MUJOCO_GL=osmesa", "-e", "MUJOCO_GL=egl",
], ],
) )

View File

@@ -131,35 +131,40 @@ class VideoRecordingTrainer(SequentialTrainer):
except ImportError: except ImportError:
return return
fps = self._get_fps() # Rendering needs a GL backend (EGL/OSMesa); never let a headless GL
max_steps = getattr(self.env.env.config, "max_steps", 500) # failure crash training — log it and skip the video.
frames: list[np.ndarray] = [] try:
fps = self._get_fps()
max_steps = getattr(self.env.env.config, "max_steps", 500)
frames: list[np.ndarray] = []
obs, _ = self.env.reset() obs, _ = self.env.reset()
with torch.no_grad(): with torch.no_grad():
for _ in range(max_steps): for _ in range(max_steps):
action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0] action = self.agents.act(obs, timestep=timestep, timesteps=self.timesteps)[0]
obs, _, terminated, truncated, _ = self.env.step(action) obs, _, terminated, truncated, _ = self.env.step(action)
frame = self.env.render() frame = self.env.render()
if frame is not None: if frame is not None:
frames.append(frame) frames.append(frame)
if (terminated | truncated).any().item(): if (terminated | truncated).any().item():
break break
if frames: if frames:
path = str(self._video_dir / f"step_{timestep}.mp4") path = str(self._video_dir / f"step_{timestep}.mp4")
iio.imwrite(path, frames, fps=fps) iio.imwrite(path, frames, fps=fps)
logger = Logger.current_logger() logger = Logger.current_logger()
if logger: if logger:
logger.report_media( logger.report_media(
"Training Video", f"step_{timestep}", "Training Video", f"step_{timestep}",
local_path=path, iteration=timestep, local_path=path, iteration=timestep,
) )
self.env.reset() self.env.reset()
except Exception as exc:
log.warning("video_recording_failed", timestep=timestep, error=str(exc))
# ── Main trainer ───────────────────────────────────────────────────── # ── Main trainer ─────────────────────────────────────────────────────