mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -49,15 +54,32 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
return tokenizer
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def is_text_encoder_outputs_cached(self, args):
|
||||
return args.cache_text_encoder_outputs
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
|
||||
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
return latents_caching_strategy
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_sdxl.SdxlTextEncodingStrategy()
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
|
||||
else:
|
||||
return None
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
@@ -70,15 +92,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# When TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
with accelerator.autocast():
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
dataset.new_cache_text_encoder_outputs(
|
||||
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user