From a98e86ef66a7f740c5ecea70afac897a87064ddb Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Wed, 10 Jun 2026 19:48:48 +0200 Subject: [PATCH] disable JAX GPU preallocation so MJX shares VRAM with torch --- src/runners/mjx.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/runners/mjx.py b/src/runners/mjx.py index ed02079..421e8bf 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -9,10 +9,17 @@ Requirements: """ import dataclasses +import os import structlog import torch +# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By +# default JAX preallocates ~75% of GPU memory on init, starving torch and +# causing OOM at the first PPO update. Disable preallocation so both libraries +# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice). +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + try: import jax import jax.numpy as jnp