add clean_memory_on_device and use it from training

This commit is contained in:
Kohya S
2024-02-12 11:10:52 +09:00
parent 75ecb047e2
commit e24d9606a2
13 changed files with 55 additions and 38 deletions

View File

@@ -31,6 +31,21 @@ def clean_memory():
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""