mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into dev_device_support
This commit is contained in:
@@ -6,7 +6,10 @@ init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
@@ -59,7 +62,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
print("move vae and unet to cpu to save memory")
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
@@ -82,7 +85,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
print("move vae and unet back to original device")
|
||||
logger.info("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
@@ -140,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# print("text encoder outputs verified")
|
||||
# logger.info("text encoder outputs verified")
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user