Merge branch 'sd3' into feat-hunyuan-image-2.1-inference

This commit is contained in:
Kohya S
2025-09-13 20:13:58 +09:00
18 changed files with 465 additions and 237 deletions

View File

@@ -1,8 +1,10 @@
import functools
import gc
from typing import Optional, Union
import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
@@ -37,12 +39,15 @@ def clean_memory():
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
def clean_memory_on_device(device: Optional[Union[str, torch.device]]):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
@@ -52,7 +57,11 @@ def clean_memory_on_device(device: torch.device):
torch.mps.empty_cache()
def synchronize_device(device: torch.device):
def synchronize_device(device: Optional[Union[str, torch.device]]):
if device is None:
return
if isinstance(device, str):
device = torch.device(device)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":