From e7f5be3934236ec5ff3ab17516f56f27a6c6d238 Mon Sep 17 00:00:00 2001
From: umisetokikaze
Date: Wed, 11 Mar 2026 20:13:00 +0900
Subject: [PATCH] Add LECO training scripts and associated tests

- Implemented `train_leco.py` and `sdxl_train_leco.py` for training with LECO prompts, including argument parsing, model setup, the training loop, and weight saving.
- Added `library/leco_train_util.py` (prompt loading/expansion and denoising helpers) and `docs/train_leco.md`.
- Created unit tests for `load_prompt_settings` in `test_leco_train_util.py` to validate loading of prompt configurations in both the original and slider formats.
- Added basic import tests for `train_leco.py` and `sdxl_train_leco.py` to ensure the modules are importable.
---
 docs/train_leco.md                    | 106 ++++++
 library/leco_train_util.py            | 526 ++++++++++++++++++++++++++++++++
 sdxl_train_leco.py                    | 402 ++++++++++++++++++++
 tests/library/test_leco_train_util.py |  71 ++++
 tests/test_sdxl_train_leco.py         |   5 +
 tests/test_train_leco.py              |   5 +
 train_leco.py                         | 381 ++++++++++++++++++
 7 files changed, 1496 insertions(+)
 create mode 100644 docs/train_leco.md
 create mode 100644 library/leco_train_util.py
 create mode 100644 sdxl_train_leco.py
 create mode 100644 tests/library/test_leco_train_util.py
 create mode 100644 tests/test_sdxl_train_leco.py
 create mode 100644 tests/test_train_leco.py
 create mode 100644 train_leco.py

diff --git a/docs/train_leco.md b/docs/train_leco.md
new file mode 100644
index 00000000..bc9265c4
--- /dev/null
+++ b/docs/train_leco.md
+# LECO Training
+
+This repository now includes dedicated LECO training entry points:
+
+- `train_leco.py` for Stable Diffusion 1.x / 2.x
+- `sdxl_train_leco.py` for SDXL
+
+These scripts train a LoRA against the model's own noise predictions, so no image dataset is required.
+
+## Current scope
+
+- U-Net LoRA training only
+- `networks.lora` is the default network module
+- Prompt YAML supports both the original LECO prompt pairs and ai-toolkit style slider targets
+- Full ai-toolkit job YAML is not supported; use a prompt/target YAML file only
+
+## Example: SD 1.x / 2.x
+
+```bash
+accelerate launch train_leco.py \
+  --pretrained_model_name_or_path="model.safetensors" \
+  --output_dir="output" \
+  --output_name="detail_slider" \
+  --prompts_file="prompts.yaml" \
+  --network_dim=8 \
+  --network_alpha=4 \
+  --learning_rate=1e-4 \
+  --max_train_steps=500 \
+  --max_denoising_steps=40 \
+  --mixed_precision=bf16
+```
+
+## Example: SDXL
+
+```bash
+accelerate launch sdxl_train_leco.py \
+  --pretrained_model_name_or_path="sdxl_model.safetensors" \
+  --output_dir="output" \
+  --output_name="detail_slider_xl" \
+  --prompts_file="slider.yaml" \
+  --network_dim=8 \
+  --network_alpha=4 \
+  --learning_rate=1e-4 \
+  --max_train_steps=500 \
+  --max_denoising_steps=40 \
+  --mixed_precision=bf16
+```
+
+## Prompt YAML: original LECO format
+
+```yaml
+- target: "van gogh"
+  positive: "van gogh"
+  unconditional: ""
+  neutral: ""
+  action: "erase"
+  guidance_scale: 1.0
+  resolution: 512
+  batch_size: 1
+  multiplier: 1.0
+  weight: 1.0
+```
+
+## Prompt YAML: ai-toolkit style slider target
+
+This expands internally into the bidirectional LECO pairs needed for slider-style behavior.
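+
+A single slider target that sets both `positive` and `negative` expands into four `PromptSettings` entries: an erase/enhance pair at `+multiplier` and a mirrored pair at `-multiplier`, so one LoRA learns both directions of the slider. A target with only one of the two prompts expands into two entries.
+
+You can sanity-check the expansion before training. A minimal sketch (run from the repository root; `slider.yaml` is a placeholder path):
+
+```python
+from library.leco_train_util import load_prompt_settings
+
+for p in load_prompt_settings("slider.yaml"):
+    print(p.action, p.multiplier, repr(p.positive), "<-", repr(p.unconditional))
+```
+
+The slider YAML itself looks like the following: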
+ +```yaml +targets: + - target_class: "" + positive: "high detail, intricate, high quality" + negative: "blurry, low detail, low quality" + multiplier: 1.0 + weight: 1.0 + +guidance_scale: 1.0 +resolution: 512 +neutral: "" +``` + +You can also provide multiple neutral prompts: + +```yaml +targets: + - target_class: "person" + positive: "smiling person" + negative: "expressionless person" + +neutrals: + - "" + - "studio photo" + - "cinematic lighting" +``` diff --git a/library/leco_train_util.py b/library/leco_train_util.py new file mode 100644 index 00000000..14da3c3d --- /dev/null +++ b/library/leco_train_util.py @@ -0,0 +1,526 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch +import yaml + + +ResolutionValue = Union[int, Tuple[int, int]] + + +@dataclass +class PromptEmbedsXL: + text_embeds: torch.Tensor + pooled_embeds: torch.Tensor + + +class PromptEmbedsCache: + def __init__(self): + self.prompts: dict[str, Any] = {} + + def __setitem__(self, name: str, value: Any) -> None: + self.prompts[name] = value + + def __getitem__(self, name: str) -> Optional[Any]: + return self.prompts.get(name) + + +@dataclass +class PromptSettings: + target: str + positive: Optional[str] = None + unconditional: str = "" + neutral: Optional[str] = None + action: str = "erase" + guidance_scale: float = 1.0 + resolution: ResolutionValue = 512 + dynamic_resolution: bool = False + batch_size: int = 1 + dynamic_crops: bool = False + multiplier: float = 1.0 + weight: float = 1.0 + + def __post_init__(self): + if self.positive is None: + self.positive = self.target + if self.neutral is None: + self.neutral = self.unconditional + if self.action not in ("erase", "enhance"): + raise ValueError(f"Invalid action: {self.action}") + + self.guidance_scale = float(self.guidance_scale) + self.batch_size = int(self.batch_size) + self.multiplier = float(self.multiplier) + self.weight = float(self.weight) + self.dynamic_resolution = bool(self.dynamic_resolution) + self.dynamic_crops = bool(self.dynamic_crops) + self.resolution = normalize_resolution(self.resolution) + + def get_resolution(self) -> Tuple[int, int]: + if isinstance(self.resolution, tuple): + return self.resolution + return (self.resolution, self.resolution) + + def build_target(self, positive_latents, neutral_latents, unconditional_latents): + offset = self.guidance_scale * (positive_latents - unconditional_latents) + if self.action == "erase": + return neutral_latents - offset + return neutral_latents + offset + + +def normalize_resolution(value: Any) -> ResolutionValue: + if isinstance(value, tuple): + if len(value) != 2: + raise ValueError(f"resolution tuple must have 2 items: {value}") + return (int(value[0]), int(value[1])) + if isinstance(value, list): + if len(value) == 2 and all(isinstance(v, (int, float)) for v in value): + return (int(value[0]), int(value[1])) + raise ValueError(f"resolution list must have 2 numeric items: {value}") + return int(value) + + +def _read_non_empty_lines(path: Union[str, Path]) -> List[str]: + with open(path, "r", encoding="utf-8") as f: + return [line.strip() for line in f.readlines() if line.strip()] + + +def _recognized_prompt_keys() -> set[str]: + return { + "target", + "positive", + "unconditional", + "neutral", + "action", + "guidance_scale", + "resolution", + "dynamic_resolution", + "batch_size", + "dynamic_crops", + "multiplier", + "weight", + } + + +def _recognized_slider_keys() -> set[str]: + return 
{ + "target_class", + "positive", + "negative", + "neutral", + "guidance_scale", + "resolution", + "resolutions", + "dynamic_resolution", + "batch_size", + "dynamic_crops", + "multiplier", + "weight", + } + + +def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]: + merged = {k: v for k, v in defaults.items() if k in known_keys} + merged.update(item) + return merged + + +def _normalize_resolution_values(value: Any) -> List[ResolutionValue]: + if value is None: + return [512] + if isinstance(value, list) and value and isinstance(value[0], (list, tuple)): + return [normalize_resolution(v) for v in value] + return [normalize_resolution(value)] + + +def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]: + target_class = str(target.get("target_class", "")) + positive = str(target.get("positive", "") or "") + negative = str(target.get("negative", "") or "") + guidance_scale = target.get("guidance_scale", 1.0) + dynamic_resolution = target.get("dynamic_resolution", False) + batch_size = target.get("batch_size", 1) + dynamic_crops = target.get("dynamic_crops", False) + multiplier = target.get("multiplier", 1.0) + weight = target.get("weight", 1.0) + resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512))) + + if not positive.strip() and not negative.strip(): + raise ValueError("slider target requires either positive or negative prompt") + + prompt_settings: List[PromptSettings] = [] + + for resolution in resolutions: + if positive.strip() and negative.strip(): + prompt_settings.extend( + [ + PromptSettings( + target=target_class, + positive=negative, + unconditional=positive, + neutral=neutral, + action="erase", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=multiplier, + weight=weight, + ), + PromptSettings( + target=target_class, + positive=positive, + unconditional=negative, + neutral=neutral, + action="enhance", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=multiplier, + weight=weight, + ), + PromptSettings( + target=target_class, + positive=positive, + unconditional=negative, + neutral=neutral, + action="erase", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=-multiplier, + weight=weight, + ), + PromptSettings( + target=target_class, + positive=negative, + unconditional=positive, + neutral=neutral, + action="enhance", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=-multiplier, + weight=weight, + ), + ] + ) + elif negative.strip(): + prompt_settings.extend( + [ + PromptSettings( + target=target_class, + positive=negative, + unconditional="", + neutral=neutral, + action="erase", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=multiplier, + weight=weight, + ), + PromptSettings( + target=target_class, + positive=negative, + unconditional="", + neutral=neutral, + action="enhance", + guidance_scale=guidance_scale, + resolution=resolution, + 
dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=-multiplier, + weight=weight, + ), + ] + ) + else: + prompt_settings.extend( + [ + PromptSettings( + target=target_class, + positive=positive, + unconditional="", + neutral=neutral, + action="enhance", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=multiplier, + weight=weight, + ), + PromptSettings( + target=target_class, + positive=positive, + unconditional="", + neutral=neutral, + action="erase", + guidance_scale=guidance_scale, + resolution=resolution, + dynamic_resolution=dynamic_resolution, + batch_size=batch_size, + dynamic_crops=dynamic_crops, + multiplier=-multiplier, + weight=weight, + ), + ] + ) + + return prompt_settings + + +def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]: + path = Path(path) + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + + if data is None: + raise ValueError("prompt file is empty") + + default_prompt_values = { + "guidance_scale": 1.0, + "resolution": 512, + "dynamic_resolution": False, + "batch_size": 1, + "dynamic_crops": False, + "multiplier": 1.0, + "weight": 1.0, + } + + prompt_settings: List[PromptSettings] = [] + + def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None: + merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys()) + prompt_settings.append(PromptSettings(**merged)) + + def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None: + merged = _merge_known_defaults(defaults, item, _recognized_slider_keys()) + if not neutral_values: + neutral_values = [str(merged.get("neutral", "") or "")] + for neutral in neutral_values: + prompt_settings.extend(_expand_slider_target(merged, neutral)) + + if isinstance(data, list): + for item in data: + if "target_class" in item: + append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")]) + else: + append_prompt_item(item, default_prompt_values) + elif isinstance(data, dict): + if "prompts" in data: + defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}} + for item in data["prompts"]: + append_prompt_item(item, defaults) + else: + slider_config = data.get("slider", data) + targets = slider_config.get("targets") + if targets is None: + if "target_class" in slider_config: + targets = [slider_config] + elif "target" in slider_config: + targets = [slider_config] + else: + raise ValueError("prompt file does not contain prompts or slider targets") + if len(targets) == 0: + raise ValueError("prompt file contains an empty targets list") + + if "target" in targets[0]: + defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}} + for item in targets: + append_prompt_item(item, defaults) + else: + defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}} + neutral_values: List[str] = [] + if "neutrals" in slider_config: + neutral_values.extend(str(v) for v in slider_config["neutrals"]) + if "neutral_prompt_file" in slider_config: + neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"])) + if "prompt_file" in slider_config: + neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"])) + if not 
neutral_values: + neutral_values = [str(slider_config.get("neutral", "") or "")] + + for item in targets: + item_neutrals = neutral_values + if "neutrals" in item: + item_neutrals = [str(v) for v in item["neutrals"]] + elif "neutral_prompt_file" in item: + item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"]) + elif "prompt_file" in item: + item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"]) + elif "neutral" in item: + item_neutrals = [str(item["neutral"] or "")] + + append_slider_item(item, defaults, item_neutrals) + else: + raise ValueError("prompt file must be a list or mapping") + + if not prompt_settings: + raise ValueError("no prompt settings found") + + return prompt_settings + + +def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor: + tokens = tokenize_strategy.tokenize(prompt) + return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0] + + +def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL: + tokens = tokenize_strategy.tokenize(prompt) + hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens) + return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2) + + +def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor: + if noise_offset is None: + return latents + return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + + +def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor: + noise = torch.randn( + (batch_size, 4, height // 8, width // 8), + device="cpu", + ).repeat(n_prompts, 1, 1, 1) + return noise * scheduler.init_noise_sigma + + +def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor: + return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0) + + +def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL: + text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0) + pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave( + batch_size, dim=0 + ) + return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds) + + +def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + noise_pred = unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + +def diffusion( + unet, + scheduler, + latents: torch.Tensor, + text_embeddings: torch.Tensor, + total_timesteps: int, + start_timesteps: int = 0, + guidance_scale: float = 3.0, +): + for timestep in scheduler.timesteps[start_timesteps:total_timesteps]: + noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale) + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + return latents + + +def get_add_time_ids( + height: int, + width: int, 
+ dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if dynamic_crops: + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + crops_coords_top_left = ( + torch.randint(0, max(original_size[0] - height, 1), (1,)).item(), + torch.randint(0, max(original_size[1] - width, 1), (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype) + if device is not None: + add_time_ids = add_time_ids.to(device) + return add_time_ids + + +def predict_noise_xl( + unet, + scheduler, + timestep, + latents: torch.Tensor, + prompt_embeds: PromptEmbedsXL, + add_time_ids: torch.Tensor, + guidance_scale: float = 1.0, +): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + added_cond_kwargs = {"text_embeds": prompt_embeds.pooled_embeds, "time_ids": add_time_ids} + noise_pred = unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds.text_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + +def diffusion_xl( + unet, + scheduler, + latents: torch.Tensor, + prompt_embeds: PromptEmbedsXL, + add_time_ids: torch.Tensor, + total_timesteps: int, + start_timesteps: int = 0, + guidance_scale: float = 3.0, +): + for timestep in scheduler.timesteps[start_timesteps:total_timesteps]: + noise_pred = predict_noise_xl( + unet, + scheduler, + timestep, + latents, + prompt_embeds=prompt_embeds, + add_time_ids=add_time_ids, + guidance_scale=guidance_scale, + ) + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + return latents + + +def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]: + max_resolution = bucket_resolution + min_resolution = bucket_resolution // 2 + step = 64 + min_step = min_resolution // step + max_step = max_resolution // step + height = torch.randint(min_step, max_step + 1, (1,)).item() * step + width = torch.randint(min_step, max_step + 1, (1,)).item() * step + return height, width + + +def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]: + height, width = prompt.get_resolution() + if prompt.dynamic_resolution and height == width: + return get_random_resolution_in_bucket(height) + return height, width diff --git a/sdxl_train_leco.py b/sdxl_train_leco.py new file mode 100644 index 00000000..ca7a65ea --- /dev/null +++ b/sdxl_train_leco.py @@ -0,0 +1,401 @@ +import argparse +import importlib +import json +import os +import random +from typing import Dict + +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from tqdm import tqdm + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util +from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training +from library.leco_train_util import ( + PromptEmbedsCache, + apply_noise_offset, + concat_embeddings, + concat_embeddings_xl, + diffusion_xl, + encode_prompt_sdxl, + get_add_time_ids, + get_initial_latents, + 
get_random_resolution, + load_prompt_settings, + predict_noise_xl, +) +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + train_util.add_sd_models_arguments(parser) + train_util.add_optimizer_arguments(parser) + train_util.add_training_arguments(parser, support_dreambooth=False) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) + add_logging_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない") + + parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml") + parser.add_argument( + "--max_denoising_steps", + type=int, + default=40, + help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数", + ) + parser.add_argument( + "--leco_denoise_guidance_scale", + type=float, + default=3.0, + help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale", + ) + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network") + parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train") + parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank") + parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha") + parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout") + parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments") + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="unsupported for LECO; kept for compatibility / LECOでは未対応", + ) + parser.add_argument( + "--network_train_unet_only", + action="store_true", + help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習", + ) + parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata") + parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + + return parser + + +def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]: + kwargs = {} + if args.network_args: + for net_arg in args.network_args: + key, value = net_arg.split("=", 1) + kwargs[key] = value + if "dropout" not in kwargs: + kwargs["dropout"] = args.network_dropout + return kwargs + + +def get_save_extension(args: argparse.Namespace) -> str: + if args.save_model_as == "ckpt": + return ".ckpt" + if args.save_model_as == "pt": + return ".pt" + return ".safetensors" + + +def save_weights( + accelerator, + network, + args: argparse.Namespace, + save_dtype, + prompt_settings, + global_step: int, + last: bool = False, +) -> None: + os.makedirs(args.output_dir, exist_ok=True) + ext = get_save_extension(args) + ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step) + ckpt_file = os.path.join(args.output_dir, 
ckpt_name) + + metadata = None + if not args.no_metadata: + metadata = { + "ss_network_module": args.network_module, + "ss_network_dim": str(args.network_dim), + "ss_network_alpha": str(args.network_alpha), + "ss_leco_prompt_count": str(len(prompt_settings)), + "ss_leco_prompts_file": os.path.basename(args.prompts_file), + "ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, + } + if args.training_comment: + metadata["ss_training_comment"] = args.training_comment + metadata["ss_leco_preview"] = json.dumps( + [ + { + "target": p.target, + "positive": p.positive, + "unconditional": p.unconditional, + "neutral": p.neutral, + "action": p.action, + "multiplier": p.multiplier, + "weight": p.weight, + } + for p in prompt_settings[:16] + ], + ensure_ascii=False, + ) + + unwrapped = accelerator.unwrap_model(network) + unwrapped.save_weights(ckpt_file, save_dtype, metadata) + logger.info(f"saved model to: {ckpt_file}") + + +def main(): + parser = setup_parser() + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + train_util.verify_training_args(args) + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) + + if args.output_dir is None: + raise ValueError("--output_dir is required") + if args.network_train_text_encoder_only: + raise ValueError("LECO does not support text encoder LoRA training") + + if args.seed is None: + args.seed = random.randint(0, 2**32 - 1) + set_seed(args.seed) + + accelerator = train_util.prepare_accelerator(args) + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + prompt_settings = load_prompt_settings(args.prompts_file) + logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}") + + _, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + del vae + text_encoders = [text_encoder1, text_encoder2] + + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.train() + + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + prompt_cache = PromptEmbedsCache() + unique_prompts = sorted( + { + prompt + for setting in prompt_settings + for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral) + } + ) + with torch.no_grad(): + for prompt in unique_prompts: + prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt) + + for text_encoder in text_encoders: + text_encoder.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + network_module = importlib.import_module(args.network_module) + net_kwargs = build_network_kwargs(args) + if args.dim_from_weights: + if args.network_weights is None: + raise 
ValueError("--dim_from_weights requires --network_weights") + network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs) + else: + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + None, + text_encoders, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + + network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True) + network.set_multiplier(0.0) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + logger.info(f"loaded network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() + + unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate + trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate) + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler) + accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet) + + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args) + optimizer_train_fn() + train_util.init_trackers(accelerator, args, "sdxl_leco_train") + + progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + while global_step < args.max_train_steps: + with accelerator.accumulate(network): + optimizer.zero_grad(set_to_none=True) + + setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()] + noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device) + + timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item() + height, width = get_random_resolution(setting) + + latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to( + accelerator.device, dtype=weight_dtype + ) + latents = apply_noise_offset(latents, args.noise_offset) + add_time_ids = get_add_time_ids( + height, + width, + dynamic_crops=setting.dynamic_crops, + dtype=weight_dtype, + device=accelerator.device, + ) + batched_time_ids = concat_embeddings(add_time_ids, add_time_ids, setting.batch_size) + + network_multiplier = accelerator.unwrap_model(network) + network_multiplier.set_multiplier(setting.multiplier) + with accelerator.autocast(): + denoised_latents = diffusion_xl( + unet, + noise_scheduler, + latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size), + add_time_ids=batched_time_ids, + total_timesteps=timesteps_to, + guidance_scale=args.leco_denoise_guidance_scale, + ) + + noise_scheduler.set_timesteps(1000, device=accelerator.device) + current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] + + network_multiplier.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + positive_latents = predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size), + add_time_ids=batched_time_ids, + 
guidance_scale=1.0,
+                )
+                neutral_latents = predict_noise_xl(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
+                    add_time_ids=batched_time_ids,
+                    guidance_scale=1.0,
+                )
+                unconditional_latents = predict_noise_xl(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
+                    add_time_ids=batched_time_ids,
+                    guidance_scale=1.0,
+                )
+
+            network_multiplier.set_multiplier(setting.multiplier)
+            with accelerator.autocast():
+                target_latents = predict_noise_xl(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+                    add_time_ids=batched_time_ids,
+                    guidance_scale=1.0,
+                )
+
+            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
+            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
+            loss = loss.mean(dim=(1, 2, 3))
+            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
+                # apply_snr_weight indexes the SNR table by timestep value, not by step index
+                timesteps = torch.full((loss.shape[0],), int(current_timestep), device=loss.device, dtype=torch.long)
+                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+            loss = loss.mean() * setting.weight
+
+            accelerator.backward(loss)
+
+            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
+
+            optimizer.step()
+            lr_scheduler.step()
+
+        if accelerator.sync_gradients:
+            global_step += 1
+            progress_bar.update(1)
+            network_multiplier = accelerator.unwrap_model(network)
+            network_multiplier.set_multiplier(0.0)
+
+            logs = {
+                "loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "guidance_scale": setting.guidance_scale,
+                "network_multiplier": setting.multiplier,
+            }
+            accelerator.log(logs, step=global_step)
+            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
+
+        if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
+            accelerator.wait_for_everyone()
+            if accelerator.is_main_process:
+                save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
+
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/library/test_leco_train_util.py b/tests/library/test_leco_train_util.py
new file mode 100644
index 00000000..e575614f
--- /dev/null
+++ b/tests/library/test_leco_train_util.py
+from pathlib import Path
+
+from library.leco_train_util import load_prompt_settings
+
+
+def test_load_prompt_settings_with_original_format(tmp_path: Path):
+    prompt_file = tmp_path / "prompts.yaml"
+    prompt_file.write_text(
+        """
+- target: "van gogh"
+  guidance_scale: 1.5
+  resolution: 512
+""".strip(),
+        encoding="utf-8",
+    )
+
+    prompts = load_prompt_settings(prompt_file)
+
+    assert len(prompts) == 1
+    assert prompts[0].target == "van gogh"
+    assert prompts[0].positive == "van gogh"
+    assert prompts[0].unconditional == ""
+    assert prompts[0].neutral == ""
+    assert prompts[0].action == "erase"
+    assert prompts[0].guidance_scale == 1.5
+
+
+def 
test_load_prompt_settings_with_slider_targets(tmp_path: Path): + prompt_file = tmp_path / "slider.yaml" + prompt_file.write_text( + """ +targets: + - target_class: "" + positive: "high detail" + negative: "low detail" + multiplier: 1.25 + weight: 0.5 +guidance_scale: 2.0 +resolution: 768 +neutral: "" +""".strip(), + encoding="utf-8", + ) + + prompts = load_prompt_settings(prompt_file) + + assert len(prompts) == 4 + + first = prompts[0] + second = prompts[1] + third = prompts[2] + fourth = prompts[3] + + assert first.target == "" + assert first.positive == "low detail" + assert first.unconditional == "high detail" + assert first.action == "erase" + assert first.multiplier == 1.25 + assert first.weight == 0.5 + assert first.get_resolution() == (768, 768) + + assert second.positive == "high detail" + assert second.unconditional == "low detail" + assert second.action == "enhance" + assert second.multiplier == 1.25 + + assert third.action == "erase" + assert third.multiplier == -1.25 + + assert fourth.action == "enhance" + assert fourth.multiplier == -1.25 diff --git a/tests/test_sdxl_train_leco.py b/tests/test_sdxl_train_leco.py new file mode 100644 index 00000000..165d5396 --- /dev/null +++ b/tests/test_sdxl_train_leco.py @@ -0,0 +1,5 @@ +import sdxl_train_leco + + +def test_syntax(): + assert sdxl_train_leco is not None diff --git a/tests/test_train_leco.py b/tests/test_train_leco.py new file mode 100644 index 00000000..27bb414a --- /dev/null +++ b/tests/test_train_leco.py @@ -0,0 +1,5 @@ +import train_leco + + +def test_syntax(): + assert train_leco is not None diff --git a/train_leco.py b/train_leco.py new file mode 100644 index 00000000..8e7c70c5 --- /dev/null +++ b/train_leco.py @@ -0,0 +1,378 @@ +import argparse +import importlib +import json +import os +import random +from typing import Dict + +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from tqdm import tqdm + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import custom_train_functions, strategy_sd, train_util +from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training +from library.leco_train_util import ( + PromptEmbedsCache, + apply_noise_offset, + concat_embeddings, + diffusion, + encode_prompt_sd, + get_initial_latents, + get_random_resolution, + load_prompt_settings, + predict_noise, +) +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + train_util.add_sd_models_arguments(parser) + train_util.add_optimizer_arguments(parser) + train_util.add_training_arguments(parser, support_dreambooth=False) + add_logging_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない") + + parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml") + parser.add_argument( + "--max_denoising_steps", + type=int, + default=40, + help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数", + ) + parser.add_argument( + "--leco_denoise_guidance_scale", + type=float, + default=3.0, + 
help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale", + ) + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network") + parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train") + parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank") + parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha") + parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout") + parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments") + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="unsupported for LECO; kept for compatibility / LECOでは未対応", + ) + parser.add_argument( + "--network_train_unet_only", + action="store_true", + help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習", + ) + parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata") + parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + + return parser + + +def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]: + kwargs = {} + if args.network_args: + for net_arg in args.network_args: + key, value = net_arg.split("=", 1) + kwargs[key] = value + if "dropout" not in kwargs: + kwargs["dropout"] = args.network_dropout + return kwargs + + +def get_save_extension(args: argparse.Namespace) -> str: + if args.save_model_as == "ckpt": + return ".ckpt" + if args.save_model_as == "pt": + return ".pt" + return ".safetensors" + + +def save_weights( + accelerator, + network, + args: argparse.Namespace, + save_dtype, + prompt_settings, + global_step: int, + last: bool = False, +) -> None: + os.makedirs(args.output_dir, exist_ok=True) + ext = get_save_extension(args) + ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + metadata = None + if not args.no_metadata: + metadata = { + "ss_network_module": args.network_module, + "ss_network_dim": str(args.network_dim), + "ss_network_alpha": str(args.network_alpha), + "ss_leco_prompt_count": str(len(prompt_settings)), + "ss_leco_prompts_file": os.path.basename(args.prompts_file), + } + if args.training_comment: + metadata["ss_training_comment"] = args.training_comment + metadata["ss_leco_preview"] = json.dumps( + [ + { + "target": p.target, + "positive": p.positive, + "unconditional": p.unconditional, + "neutral": p.neutral, + "action": p.action, + "multiplier": p.multiplier, + "weight": p.weight, + } + for p in prompt_settings[:16] + ], + ensure_ascii=False, + ) + + unwrapped = accelerator.unwrap_model(network) + unwrapped.save_weights(ckpt_file, save_dtype, metadata) + logger.info(f"saved model to: {ckpt_file}") + + +def main(): + parser = setup_parser() + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + train_util.verify_training_args(args) + + if args.output_dir is None: + raise ValueError("--output_dir is required") + if args.network_train_text_encoder_only: + raise ValueError("LECO does not support text encoder LoRA training") + + if args.seed is None: + args.seed = random.randint(0, 2**32 - 1) 
+ set_seed(args.seed) + + accelerator = train_util.prepare_accelerator(args) + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + prompt_settings = load_prompt_settings(args.prompts_file) + logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}") + + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + del vae + + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.train() + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + prompt_cache = PromptEmbedsCache() + unique_prompts = sorted( + { + prompt + for setting in prompt_settings + for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral) + } + ) + with torch.no_grad(): + for prompt in unique_prompts: + prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt) + + text_encoder.to("cpu") + clean_memory_on_device(accelerator.device) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + network_module = importlib.import_module(args.network_module) + net_kwargs = build_network_kwargs(args) + if args.dim_from_weights: + if args.network_weights is None: + raise ValueError("--dim_from_weights requires --network_weights") + network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs) + else: + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + None, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + + network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True) + network.set_multiplier(0.0) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + logger.info(f"loaded network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() + + unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate + trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate) + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler) + accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) + + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args) + optimizer_train_fn() + train_util.init_trackers(accelerator, args, "leco_train") + + progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + while global_step < 
args.max_train_steps:
+        with accelerator.accumulate(network):
+            optimizer.zero_grad(set_to_none=True)
+
+            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
+            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
+
+            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
+            height, width = get_random_resolution(setting)
+
+            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
+                accelerator.device, dtype=weight_dtype
+            )
+            latents = apply_noise_offset(latents, args.noise_offset)
+
+            network_multiplier = accelerator.unwrap_model(network)
+            network_multiplier.set_multiplier(setting.multiplier)
+            with accelerator.autocast():
+                denoised_latents = diffusion(
+                    unet,
+                    noise_scheduler,
+                    latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+                    total_timesteps=timesteps_to,
+                    guidance_scale=args.leco_denoise_guidance_scale,
+                )
+
+            noise_scheduler.set_timesteps(1000, device=accelerator.device)
+            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
+            current_timestep = noise_scheduler.timesteps[current_timestep_index]
+
+            network_multiplier.set_multiplier(0.0)
+            with torch.no_grad(), accelerator.autocast():
+                positive_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+                neutral_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+                unconditional_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+
+            network_multiplier.set_multiplier(setting.multiplier)
+            with accelerator.autocast():
+                target_latents = predict_noise(
+                    unet,
+                    noise_scheduler,
+                    current_timestep,
+                    denoised_latents,
+                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+                    guidance_scale=1.0,
+                )
+
+            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
+            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
+            loss = loss.mean(dim=(1, 2, 3))
+            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
+                # apply_snr_weight indexes the SNR table by timestep value, not by step index
+                timesteps = torch.full((loss.shape[0],), int(current_timestep), device=loss.device, dtype=torch.long)
+                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+            loss = loss.mean() * setting.weight
+
+            accelerator.backward(loss)
+
+            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
+
+            optimizer.step()
+            lr_scheduler.step()
+
+        if accelerator.sync_gradients:
+            global_step += 1
+            progress_bar.update(1)
+            network_multiplier = accelerator.unwrap_model(network)
+            network_multiplier.set_multiplier(0.0)
+
+            logs = {
+                "loss": loss.detach().item(),
+                "lr": lr_scheduler.get_last_lr()[0],
+                "guidance_scale": setting.guidance_scale,
+                "network_multiplier": setting.multiplier,
+            }
+            accelerator.log(logs, step=global_step)
+            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
+
+        if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
+            accelerator.wait_for_everyone()
+            if accelerator.is_main_process:
+                save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
+
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
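
Both training scripts regress the LoRA-enabled prediction (`target_latents`) onto `PromptSettings.build_target(...)`, which combines three predictions made with the LoRA multiplier set to 0. A minimal sketch of that arithmetic with dummy tensors (shapes are illustrative; assumes the patch is applied and the script is run from the repository root):

```python
import torch

from library.leco_train_util import PromptSettings

# dummy noise predictions in the U-Net's output shape (batch, 4, h/8, w/8)
positive = torch.randn(1, 4, 64, 64)
neutral = torch.randn(1, 4, 64, 64)
unconditional = torch.randn(1, 4, 64, 64)

erase = PromptSettings(target="van gogh", action="erase", guidance_scale=1.0)
enhance = PromptSettings(target="van gogh", action="enhance", guidance_scale=1.0)

# erase: neutral - g * (positive - unconditional); enhance flips the sign
offset = erase.guidance_scale * (positive - unconditional)
assert torch.allclose(erase.build_target(positive, neutral, unconditional), neutral - offset)
assert torch.allclose(enhance.build_target(positive, neutral, unconditional), neutral + offset)
```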