disable JAX GPU preallocation so MJX shares VRAM with torch

This commit is contained in:
2026-06-10 19:48:48 +02:00
parent 4210b6cb53
commit a98e86ef66

View File

@@ -9,10 +9,17 @@ Requirements:
"""
import dataclasses
import os
import structlog
import torch
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
# default JAX preallocates ~75% of GPU memory on init, starving torch and
# causing OOM at the first PPO update. Disable preallocation so both libraries
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
try:
import jax
import jax.numpy as jnp