mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add clean_memory_on_device and use it from training
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
@@ -64,7 +64,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
with accelerator.autocast():
|
||||
@@ -79,7 +79,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
clean_memory()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
print("move vae and unet back to original device")
|
||||
|
||||
Reference in New Issue
Block a user