disable JAX GPU preallocation so MJX shares VRAM with torch
This commit is contained in:
@@ -9,10 +9,17 @@ Requirements:
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
|
||||
import structlog
|
||||
import torch
|
||||
|
||||
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
|
||||
# default JAX preallocates ~75% of GPU memory on init, starving torch and
|
||||
# causing OOM at the first PPO update. Disable preallocation so both libraries
|
||||
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
|
||||
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
Reference in New Issue
Block a user