Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -17,7 +17,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sdxl_model_util
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
import library.train_util as train_util
@@ -124,7 +124,16 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
@@ -166,10 +175,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -262,8 +271,9 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -276,6 +286,9 @@ def train(args):
train_text_encoder1 = False
train_text_encoder2 = False
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if args.train_text_encoder:
# TODO each option for two text encoders?
accelerator.print("enable text encoder training")
@@ -307,16 +320,17 @@ def train(args):
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
accelerator.wait_for_everyone()
if not cache_latents:
vae.requires_grad_(False)
@@ -403,7 +417,11 @@ def train(args):
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -597,7 +615,7 @@ def train(args):
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
)
loss_recorder = train_util.LossRecorder()
@@ -628,9 +646,15 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
@@ -646,39 +670,13 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# # verify that the text encoder outputs are correct
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
# args.max_token_length,
# batch["input_ids"].to(text_encoder1.device),
# batch["input_ids2"].to(text_encoder1.device),
# tokenizer1,
# tokenizer2,
# text_encoder1,
# text_encoder2,
# None if not args.full_fp16 else weight_dtype,
# )
# b_size = encoder_hidden_states1.shape[0]
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# logger.info("text encoder outputs verified")
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -765,7 +763,7 @@ def train(args):
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
tokenizers,
[text_encoder1, text_encoder2],
unet,
)
@@ -847,7 +845,7 @@ def train(args):
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
tokenizers,
[text_encoder1, text_encoder2],
unet,
)