diff --git a/src/runners/mjx.py b/src/runners/mjx.py index ed02079..421e8bf 100644 --- a/src/runners/mjx.py +++ b/src/runners/mjx.py @@ -9,10 +9,17 @@ Requirements: """ import dataclasses +import os import structlog import torch +# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By +# default JAX preallocates ~75% of GPU memory on init, starving torch and +# causing OOM at the first PPO update. Disable preallocation so both libraries +# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice). +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") + try: import jax import jax.numpy as jnp