mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
122
train_network.py
122
train_network.py
@@ -7,6 +7,7 @@ import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -18,7 +19,7 @@ init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import deepspeed_utils, model_util
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset
|
||||
@@ -101,19 +102,31 @@ class NetworkTrainer:
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
return tokenizer
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def is_text_encoder_outputs_cached(self, args):
|
||||
return False
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
|
||||
return [tokenize_strategy.tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
return latents_caching_strategy
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
return None
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
return text_encoders
|
||||
|
||||
def is_train_text_encoder(self, args):
|
||||
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
|
||||
return not args.network_train_unet_only
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
):
|
||||
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
@@ -123,7 +136,7 @@ class NetworkTrainer:
|
||||
return encoder_hidden_states
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||
return noise_pred
|
||||
|
||||
def all_reduce_network(self, accelerator, network):
|
||||
@@ -131,8 +144,8 @@ class NetworkTrainer:
|
||||
if param.grad is not None:
|
||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -150,9 +163,13 @@ class NetworkTrainer:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
# tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため
|
||||
tokenizer = self.load_tokenizer(args)
|
||||
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
||||
tokenize_strategy = self.get_tokenize_strategy(args)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
|
||||
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = self.get_latents_caching_strategy(args)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
@@ -194,11 +211,11 @@ class NetworkTrainer:
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -268,8 +285,9 @@ class NetworkTrainer:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -277,9 +295,13 @@ class NetworkTrainer:
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
text_encoding_strategy = self.get_text_encoding_strategy(args)
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
text_encoder_outputs_caching_strategy = self.get_text_encoder_outputs_caching_strategy(args)
|
||||
if text_encoder_outputs_caching_strategy is not None:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
@@ -366,7 +388,11 @@ class NetworkTrainer:
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
||||
@@ -878,7 +904,7 @@ class NetworkTrainer:
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
@@ -933,21 +959,31 @@ class NetworkTrainer:
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
text_encoder_conds = self.get_text_cond(
|
||||
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
||||
)
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
else:
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# SD only
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizers[0],
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -1026,7 +1062,9 @@ class NetworkTrainer:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
self.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -1082,7 +1120,7 @@ class NetworkTrainer:
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user