# Source: Kohya-ss-sd-scripts/train_leco.py
# NOTE: the original paste carried web-viewer chrome ("Files / Raw Blame History",
# line/size counts, and an ambiguous-Unicode warning banner). That page furniture
# is not part of the script and has been converted to this comment header so the
# file is valid Python.
import argparse
import importlib
import json
import os
import random
from typing import Dict
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import custom_train_functions, strategy_sd, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
PromptEmbedsCache,
apply_noise_offset,
concat_embeddings,
diffusion,
encode_prompt_sd,
get_initial_latents,
get_random_resolution,
load_prompt_settings,
predict_noise,
)
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_optimizer_arguments(parser)
train_util.add_training_arguments(parser, support_dreambooth=False)
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
add_logging_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml")
parser.add_argument(
"--max_denoising_steps",
type=int,
default=40,
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
)
parser.add_argument(
"--leco_denoise_guidance_scale",
type=float,
default=3.0,
help="guidance scale for the partial denoising pass / 部分デイズ時のguidance scale",
)
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
)
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
return parser
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
    """Parse ``--network_args`` "key=value" entries into a kwargs dict for the network module.

    Each entry is split on the FIRST "=", so values may themselves contain "=".
    ``--network_dropout`` is forwarded under the "dropout" key unless the user
    already supplied one via ``--network_args``.

    Note: values parsed from --network_args are strings, while the injected
    "dropout" value is the float from --network_dropout (annotation kept for
    compatibility with the original signature).

    Args:
        args: parsed CLI namespace (reads ``network_args`` and ``network_dropout``).

    Returns:
        Keyword arguments to pass to ``create_network`` / ``create_network_from_weights``.

    Raises:
        ValueError: if a --network_args entry does not contain "=".
    """
    kwargs: Dict[str, str] = {}
    if args.network_args:
        for net_arg in args.network_args:
            if "=" not in net_arg:
                # Same exception type the bare unpack raised, but with a usable message.
                raise ValueError(f"--network_args entries must be key=value, got: {net_arg!r}")
            key, value = net_arg.split("=", 1)
            kwargs[key] = value
    # Only inject dropout when it was explicitly requested. The original code
    # also inserted "dropout": None, passing a spurious dropout=None keyword
    # to network modules when --network_dropout was not given.
    if "dropout" not in kwargs and args.network_dropout is not None:
        kwargs["dropout"] = args.network_dropout
    return kwargs
def get_save_extension(args: argparse.Namespace) -> str:
    """Map ``--save_model_as`` to a checkpoint file extension.

    Any value other than "ckpt" or "pt" (including None) falls back to
    ".safetensors", matching the parser default.
    """
    return {"ckpt": ".ckpt", "pt": ".pt"}.get(args.save_model_as, ".safetensors")
def save_weights(
accelerator,
network,
args: argparse.Namespace,
save_dtype,
prompt_settings,
global_step: int,
last: bool = False,
) -> None:
os.makedirs(args.output_dir, exist_ok=True)
ext = get_save_extension(args)
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata = None
if not args.no_metadata:
metadata = {
"ss_network_module": args.network_module,
"ss_network_dim": str(args.network_dim),
"ss_network_alpha": str(args.network_alpha),
"ss_leco_prompt_count": str(len(prompt_settings)),
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
}
if args.training_comment:
metadata["ss_training_comment"] = args.training_comment
metadata["ss_leco_preview"] = json.dumps(
[
{
"target": p.target,
"positive": p.positive,
"unconditional": p.unconditional,
"neutral": p.neutral,
"action": p.action,
"multiplier": p.multiplier,
"weight": p.weight,
}
for p in prompt_settings[:16]
],
ensure_ascii=False,
)
unwrapped = accelerator.unwrap_model(network)
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
logger.info(f"saved model to: {ckpt_file}")
def main():
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train_util.verify_training_args(args)
if args.output_dir is None:
raise ValueError("--output_dir is required")
if args.network_train_text_encoder_only:
raise ValueError("LECO does not support text encoder LoRA training")
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)
set_seed(args.seed)
accelerator = train_util.prepare_accelerator(args)
weight_dtype, save_dtype = train_util.prepare_dtype(args)
prompt_settings = load_prompt_settings(args.prompts_file)
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
del vae
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.train()
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
prompt_cache = PromptEmbedsCache()
unique_prompts = sorted(
{
prompt
for setting in prompt_settings
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
}
)
with torch.no_grad():
for prompt in unique_prompts:
prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
text_encoder.to("cpu")
clean_memory_on_device(accelerator.device)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
network_module = importlib.import_module(args.network_module)
net_kwargs = build_network_kwargs(args)
if args.dim_from_weights:
if args.network_weights is None:
raise ValueError("--dim_from_weights requires --network_weights")
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
else:
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
None,
text_encoder,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
network.set_multiplier(0.0)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
logger.info(f"loaded network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing()
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
optimizer_train_fn()
train_util.init_trackers(accelerator, args, "leco_train")
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
while global_step < args.max_train_steps:
with accelerator.accumulate(network):
optimizer.zero_grad(set_to_none=True)
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
height, width = get_random_resolution(setting)
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
accelerator.device, dtype=weight_dtype
)
latents = apply_noise_offset(latents, args.noise_offset)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
denoised_latents = diffusion(
unet,
noise_scheduler,
latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
total_timesteps=timesteps_to,
guidance_scale=args.leco_denoise_guidance_scale,
)
noise_scheduler.set_timesteps(1000, device=accelerator.device)
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
network_multiplier.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
positive_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
guidance_scale=1.0,
)
neutral_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
guidance_scale=1.0,
)
unconditional_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
guidance_scale=1.0,
)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
target_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
guidance_scale=1.0,
)
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() * setting.weight
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(0.0)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"guidance_scale": setting.guidance_scale,
"network_multiplier": setting.multiplier,
}
accelerator.log(logs, step=global_step)
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
accelerator.end_training()
if __name__ == "__main__":
main()