Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -7,6 +7,7 @@ import random
import time
import json
from multiprocessing import Value
from typing import Any, List
import toml
from tqdm import tqdm
@@ -18,7 +19,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, model_util
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
import library.train_util as train_util
from library.train_util import DreamBoothDataset
@@ -101,19 +102,31 @@ class NetworkTrainer:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
def is_text_encoder_outputs_cached(self, args):
return False
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
def get_text_encoder_outputs_caching_strategy(self, args):
return None
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
return text_encoders
def is_train_text_encoder(self, args):
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
return not args.network_train_unet_only
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
@@ -123,7 +136,7 @@ class NetworkTrainer:
return encoder_hidden_states
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
@@ -131,8 +144,8 @@ class NetworkTrainer:
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -150,9 +163,13 @@ class NetworkTrainer:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
# tokenizerは単体またはリスト、tokenizersは必ずリスト：既存のコードとの互換性のため
tokenizer = self.load_tokenizer(args)
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
tokenize_strategy = self.get_tokenize_strategy(args)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization.
latents_caching_strategy = self.get_latents_caching_strategy(args)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
@@ -194,11 +211,11 @@ class NetworkTrainer:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -268,8 +285,9 @@ class NetworkTrainer:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -277,9 +295,13 @@ class NetworkTrainer:
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
text_encoding_strategy = self.get_text_encoding_strategy(args)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
if text_encoder_outputs_caching_strategy is not None:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
# prepare network
net_kwargs = {}
@@ -366,7 +388,11 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数：0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -878,7 +904,7 @@ class NetworkTrainer:
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
@@ -933,21 +959,31 @@ class NetworkTrainer:
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
else:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# SD only
text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -1026,7 +1062,9 @@ class NetworkTrainer:
progress_bar.update(1)
global_step += 1
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1082,7 +1120,7 @@ class NetworkTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
# end of epoch