mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge branch 'sd3' into feat-hunyuan-image-2.1-inference
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import functools
|
||||
import gc
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
# intel gpu support for pytorch older than 2.5
|
||||
# ipex is not needed after pytorch 2.5
|
||||
@@ -37,12 +39,15 @@ def clean_memory():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def clean_memory_on_device(device: torch.device):
|
||||
def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
|
||||
r"""
|
||||
Clean memory on the specified device, will be called from training scripts.
|
||||
"""
|
||||
gc.collect()
|
||||
|
||||
if device is None:
|
||||
return
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
# device may "cuda" or "cuda:0", so we need to check the type of device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
@@ -52,7 +57,11 @@ def clean_memory_on_device(device: torch.device):
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def synchronize_device(device: torch.device):
|
||||
def synchronize_device(device: Optional[Union[str, torch.device]]):
|
||||
if device is None:
|
||||
return
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize()
|
||||
elif device.type == "xpu":
|
||||
|
||||
Reference in New Issue
Block a user