disable JAX GPU preallocation so MJX shares VRAM with torch
This commit is contained in:
@@ -9,10 +9,17 @@ Requirements:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import os
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# JAX (MJX physics) shares the GPU with PyTorch (policy + optimizer). By
|
||||||
|
# default JAX preallocates ~75% of GPU memory on init, starving torch and
|
||||||
|
# causing OOM at the first PPO update. Disable preallocation so both libraries
|
||||||
|
# grow on demand — essential on small vGPU slices (e.g. a 6 GB HAMI slice).
|
||||||
|
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|||||||
Reference in New Issue
Block a user