This commit is contained in:
Umisetokikaze
2026-03-19 11:46:05 +09:00
committed by GitHub
10 changed files with 1573 additions and 11 deletions

1
.gitignore vendored
View File

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
.codex-tmp

93
docs/train_leco.md Normal file
View File

@@ -0,0 +1,93 @@
# LECO Training
This repository now includes dedicated LECO training entry points:
- `train_leco.py` for Stable Diffusion 1.x / 2.x
- `sdxl_train_leco.py` for SDXL
These scripts train a LoRA against the model's own noise predictions, so no image dataset is required.
## Current scope
- U-Net LoRA training only
- `networks.lora` is the default network module
- Prompt YAML supports both original LECO prompt pairs and ai-toolkit style slider targets
- Full ai-toolkit job YAML is not supported; use a prompt/target YAML file only
## Example: SD 1.x / 2.x
```bash
accelerate launch train_leco.py ^
--pretrained_model_name_or_path="model.safetensors" ^
--output_dir="output" ^
--output_name="detail_slider" ^
--prompts_file="prompts.yaml" ^
--network_dim=8 ^
--network_alpha=4 ^
--learning_rate=1e-4 ^
--max_train_steps=500 ^
--max_denoising_steps=40 ^
--mixed_precision=bf16
```
## Example: SDXL
```bash
accelerate launch sdxl_train_leco.py ^
--pretrained_model_name_or_path="sdxl_model.safetensors" ^
--output_dir="output" ^
--output_name="detail_slider_xl" ^
--prompts_file="slider.yaml" ^
--network_dim=8 ^
--network_alpha=4 ^
--learning_rate=1e-4 ^
--max_train_steps=500 ^
--max_denoising_steps=40 ^
--mixed_precision=bf16
```
## Prompt YAML: original LECO format
```yaml
- target: "van gogh"
positive: "van gogh"
unconditional: ""
neutral: ""
action: "erase"
guidance_scale: 1.0
resolution: 512
batch_size: 1
multiplier: 1.0
weight: 1.0
```
## Prompt YAML: ai-toolkit style slider target
This expands internally into the bidirectional LECO pairs needed for slider-style behavior.
```yaml
targets:
- target_class: ""
positive: "high detail, intricate, high quality"
negative: "blurry, low detail, low quality"
multiplier: 1.0
weight: 1.0
guidance_scale: 1.0
resolution: 512
neutral: ""
```
You can also provide multiple neutral prompts:
```yaml
targets:
- target_class: "person"
positive: "smiling person"
negative: "expressionless person"
neutrals:
- ""
- "studio photo"
- "cinematic lighting"
```

View File

@@ -62,7 +62,7 @@ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
if not getattr(args, "deepspeed", False):
return
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
@@ -70,7 +70,7 @@ def prepare_deepspeed_args(args: argparse.Namespace):
def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
if not getattr(args, "deepspeed", False):
return None
try:

542
library/leco_train_util.py Normal file
View File

@@ -0,0 +1,542 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
import torch
import yaml
from torch.utils.checkpoint import checkpoint
from library import sdxl_train_util
ResolutionValue = Union[int, Tuple[int, int]]
@dataclass
class PromptEmbedsXL:
text_embeds: torch.Tensor
pooled_embeds: torch.Tensor
class PromptEmbedsCache:
def __init__(self):
self.prompts: dict[str, Any] = {}
def __setitem__(self, name: str, value: Any) -> None:
self.prompts[name] = value
def __getitem__(self, name: str) -> Optional[Any]:
return self.prompts.get(name)
@dataclass
class PromptSettings:
target: str
positive: Optional[str] = None
unconditional: str = ""
neutral: Optional[str] = None
action: str = "erase"
guidance_scale: float = 1.0
resolution: ResolutionValue = 512
dynamic_resolution: bool = False
batch_size: int = 1
dynamic_crops: bool = False
multiplier: float = 1.0
weight: float = 1.0
def __post_init__(self):
if self.positive is None:
self.positive = self.target
if self.neutral is None:
self.neutral = self.unconditional
if self.action not in ("erase", "enhance"):
raise ValueError(f"Invalid action: {self.action}")
self.guidance_scale = float(self.guidance_scale)
self.batch_size = int(self.batch_size)
self.multiplier = float(self.multiplier)
self.weight = float(self.weight)
self.dynamic_resolution = bool(self.dynamic_resolution)
self.dynamic_crops = bool(self.dynamic_crops)
self.resolution = normalize_resolution(self.resolution)
def get_resolution(self) -> Tuple[int, int]:
if isinstance(self.resolution, tuple):
return self.resolution
return (self.resolution, self.resolution)
def build_target(self, positive_latents, neutral_latents, unconditional_latents):
offset = self.guidance_scale * (positive_latents - unconditional_latents)
if self.action == "erase":
return neutral_latents - offset
return neutral_latents + offset
def normalize_resolution(value: Any) -> ResolutionValue:
if isinstance(value, tuple):
if len(value) != 2:
raise ValueError(f"resolution tuple must have 2 items: {value}")
return (int(value[0]), int(value[1]))
if isinstance(value, list):
if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
return (int(value[0]), int(value[1]))
raise ValueError(f"resolution list must have 2 numeric items: {value}")
return int(value)
def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
with open(path, "r", encoding="utf-8") as f:
return [line.strip() for line in f.readlines() if line.strip()]
def _recognized_prompt_keys() -> set[str]:
return {
"target",
"positive",
"unconditional",
"neutral",
"action",
"guidance_scale",
"resolution",
"dynamic_resolution",
"batch_size",
"dynamic_crops",
"multiplier",
"weight",
}
def _recognized_slider_keys() -> set[str]:
return {
"target_class",
"positive",
"negative",
"neutral",
"guidance_scale",
"resolution",
"resolutions",
"dynamic_resolution",
"batch_size",
"dynamic_crops",
"multiplier",
"weight",
}
def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
merged = {k: v for k, v in defaults.items() if k in known_keys}
merged.update(item)
return merged
def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
if value is None:
return [512]
if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
return [normalize_resolution(v) for v in value]
return [normalize_resolution(value)]
def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
target_class = str(target.get("target_class", ""))
positive = str(target.get("positive", "") or "")
negative = str(target.get("negative", "") or "")
guidance_scale = target.get("guidance_scale", 1.0)
dynamic_resolution = target.get("dynamic_resolution", False)
batch_size = target.get("batch_size", 1)
dynamic_crops = target.get("dynamic_crops", False)
multiplier = target.get("multiplier", 1.0)
weight = target.get("weight", 1.0)
resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
if not positive.strip() and not negative.strip():
raise ValueError("slider target requires either positive or negative prompt")
prompt_settings: List[PromptSettings] = []
for resolution in resolutions:
if positive.strip() and negative.strip():
prompt_settings.extend(
[
PromptSettings(
target=target_class,
positive=negative,
unconditional=positive,
neutral=neutral,
action="erase",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=multiplier,
weight=weight,
),
PromptSettings(
target=target_class,
positive=positive,
unconditional=negative,
neutral=neutral,
action="enhance",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=multiplier,
weight=weight,
),
PromptSettings(
target=target_class,
positive=positive,
unconditional=negative,
neutral=neutral,
action="erase",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=-multiplier,
weight=weight,
),
PromptSettings(
target=target_class,
positive=negative,
unconditional=positive,
neutral=neutral,
action="enhance",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=-multiplier,
weight=weight,
),
]
)
elif negative.strip():
prompt_settings.extend(
[
PromptSettings(
target=target_class,
positive=negative,
unconditional="",
neutral=neutral,
action="erase",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=multiplier,
weight=weight,
),
PromptSettings(
target=target_class,
positive=negative,
unconditional="",
neutral=neutral,
action="enhance",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=-multiplier,
weight=weight,
),
]
)
else:
prompt_settings.extend(
[
PromptSettings(
target=target_class,
positive=positive,
unconditional="",
neutral=neutral,
action="enhance",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=multiplier,
weight=weight,
),
PromptSettings(
target=target_class,
positive=positive,
unconditional="",
neutral=neutral,
action="erase",
guidance_scale=guidance_scale,
resolution=resolution,
dynamic_resolution=dynamic_resolution,
batch_size=batch_size,
dynamic_crops=dynamic_crops,
multiplier=-multiplier,
weight=weight,
),
]
)
return prompt_settings
def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
path = Path(path)
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if data is None:
raise ValueError("prompt file is empty")
default_prompt_values = {
"guidance_scale": 1.0,
"resolution": 512,
"dynamic_resolution": False,
"batch_size": 1,
"dynamic_crops": False,
"multiplier": 1.0,
"weight": 1.0,
}
prompt_settings: List[PromptSettings] = []
def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
prompt_settings.append(PromptSettings(**merged))
def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
if not neutral_values:
neutral_values = [str(merged.get("neutral", "") or "")]
for neutral in neutral_values:
prompt_settings.extend(_expand_slider_target(merged, neutral))
if isinstance(data, list):
for item in data:
if "target_class" in item:
append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
else:
append_prompt_item(item, default_prompt_values)
elif isinstance(data, dict):
if "prompts" in data:
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
for item in data["prompts"]:
append_prompt_item(item, defaults)
else:
slider_config = data.get("slider", data)
targets = slider_config.get("targets")
if targets is None:
if "target_class" in slider_config:
targets = [slider_config]
elif "target" in slider_config:
targets = [slider_config]
else:
raise ValueError("prompt file does not contain prompts or slider targets")
if len(targets) == 0:
raise ValueError("prompt file contains an empty targets list")
if "target" in targets[0]:
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
for item in targets:
append_prompt_item(item, defaults)
else:
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
neutral_values: List[str] = []
if "neutrals" in slider_config:
neutral_values.extend(str(v) for v in slider_config["neutrals"])
if "neutral_prompt_file" in slider_config:
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
if "prompt_file" in slider_config:
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
if not neutral_values:
neutral_values = [str(slider_config.get("neutral", "") or "")]
for item in targets:
item_neutrals = neutral_values
if "neutrals" in item:
item_neutrals = [str(v) for v in item["neutrals"]]
elif "neutral_prompt_file" in item:
item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
elif "prompt_file" in item:
item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
elif "neutral" in item:
item_neutrals = [str(item["neutral"] or "")]
append_slider_item(item, defaults, item_neutrals)
else:
raise ValueError("prompt file must be a list or mapping")
if not prompt_settings:
raise ValueError("no prompt settings found")
return prompt_settings
def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
tokens = tokenize_strategy.tokenize(prompt)
return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
tokens = tokenize_strategy.tokenize(prompt)
hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
if noise_offset is None:
return latents
return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
noise = torch.randn(
(batch_size, 4, height // 8, width // 8),
device="cpu",
).repeat(n_prompts, 1, 1, 1)
return noise * scheduler.init_noise_sigma
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
batch_size, dim=0
)
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
def _run_with_checkpoint(function, *args):
if torch.is_grad_enabled():
return checkpoint(function, *args, use_reentrant=False)
return function(*args)
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
def run_unet(model_input, encoder_hidden_states):
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion(
unet,
scheduler,
latents: torch.Tensor,
text_embeddings: torch.Tensor,
total_timesteps: int,
start_timesteps: int = 0,
guidance_scale: float = 3.0,
):
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
return latents
def get_add_time_ids(
height: int,
width: int,
dynamic_crops: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
if dynamic_crops:
random_scale = torch.rand(1).item() * 2 + 1
original_size = (int(height * random_scale), int(width * random_scale))
crops_coords_top_left = (
torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
)
target_size = (height, width)
else:
original_size = (height, width)
crops_coords_top_left = (0, 0)
target_size = (height, width)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
if device is not None:
add_time_ids = add_time_ids.to(device)
return add_time_ids
def predict_noise_xl(
unet,
scheduler,
timestep,
latents: torch.Tensor,
prompt_embeds: PromptEmbedsXL,
add_time_ids: torch.Tensor,
guidance_scale: float = 1.0,
):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
orig_size = add_time_ids[:, :2]
crop_size = add_time_ids[:, 2:4]
target_size = add_time_ids[:, 4:6]
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
def run_unet(model_input, text_embeds, vector_embeds):
return unet(model_input, timestep, text_embeds, vector_embeds)
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion_xl(
unet,
scheduler,
latents: torch.Tensor,
prompt_embeds: PromptEmbedsXL,
add_time_ids: torch.Tensor,
total_timesteps: int,
start_timesteps: int = 0,
guidance_scale: float = 3.0,
):
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
noise_pred = predict_noise_xl(
unet,
scheduler,
timestep,
latents,
prompt_embeds=prompt_embeds,
add_time_ids=add_time_ids,
guidance_scale=guidance_scale,
)
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
return latents
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
max_resolution = bucket_resolution
min_resolution = bucket_resolution // 2
step = 64
min_step = min_resolution // step
max_step = max_resolution // step
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
return height, width
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
height, width = prompt.get_resolution()
if prompt.dynamic_resolution and height == width:
return get_random_resolution_in_bucket(height)
return height, width

View File

@@ -4405,7 +4405,7 @@ def verify_command_line_training_args(args: argparse.Namespace):
def enable_high_vram(args: argparse.Namespace):
if args.highvram:
if getattr(args, "highvram", False):
logger.info("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
@@ -4418,10 +4418,10 @@ def verify_training_args(args: argparse.Namespace):
"""
enable_high_vram(args)
if args.v2 and args.clip_skip is not None:
if getattr(args, "v2", False) and getattr(args, "clip_skip", None) is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
if getattr(args, "cache_latents_to_disk", False) and not getattr(args, "cache_latents", False):
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
@@ -4440,32 +4440,32 @@ def verify_training_args(args: argparse.Namespace):
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
# )
if args.adaptive_noise_scale is not None and args.noise_offset is None:
if getattr(args, "adaptive_noise_scale", None) is not None and getattr(args, "noise_offset", None) is None:
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
if getattr(args, "scale_v_pred_loss_like_noise_pred", False) and not getattr(args, "v_parameterization", False):
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
if args.v_pred_like_loss and args.v_parameterization:
if getattr(args, "v_pred_like_loss", None) and getattr(args, "v_parameterization", False):
raise ValueError(
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
)
if args.zero_terminal_snr and not args.v_parameterization:
if getattr(args, "zero_terminal_snr", False) and not getattr(args, "v_parameterization", False):
logger.warning(
f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
)
if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
if getattr(args, "sample_every_n_epochs", None) is not None and args.sample_every_n_epochs <= 0:
logger.warning(
"sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
)
args.sample_every_n_epochs = None
if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
if getattr(args, "sample_every_n_steps", None) is not None and args.sample_every_n_steps <= 0:
logger.warning(
"sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
)

402
sdxl_train_leco.py Normal file
View File

@@ -0,0 +1,402 @@
import argparse
import importlib
import json
import os
import random
from typing import Dict
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
PromptEmbedsCache,
apply_noise_offset,
concat_embeddings,
concat_embeddings_xl,
diffusion_xl,
encode_prompt_sdxl,
get_add_time_ids,
get_initial_latents,
get_random_resolution,
load_prompt_settings,
predict_noise_xl,
)
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_optimizer_arguments(parser)
train_util.add_training_arguments(parser, support_dreambooth=False)
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
add_logging_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml")
parser.add_argument(
"--max_denoising_steps",
type=int,
default=40,
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
)
parser.add_argument(
"--leco_denoise_guidance_scale",
type=float,
default=3.0,
help="guidance scale for the partial denoising pass / 部分デイズ時のguidance scale",
)
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
)
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
return parser
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
kwargs = {}
if args.network_args:
for net_arg in args.network_args:
key, value = net_arg.split("=", 1)
kwargs[key] = value
if "dropout" not in kwargs:
kwargs["dropout"] = args.network_dropout
return kwargs
def get_save_extension(args: argparse.Namespace) -> str:
if args.save_model_as == "ckpt":
return ".ckpt"
if args.save_model_as == "pt":
return ".pt"
return ".safetensors"
def save_weights(
accelerator,
network,
args: argparse.Namespace,
save_dtype,
prompt_settings,
global_step: int,
last: bool = False,
) -> None:
os.makedirs(args.output_dir, exist_ok=True)
ext = get_save_extension(args)
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata = None
if not args.no_metadata:
metadata = {
"ss_network_module": args.network_module,
"ss_network_dim": str(args.network_dim),
"ss_network_alpha": str(args.network_alpha),
"ss_leco_prompt_count": str(len(prompt_settings)),
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0,
}
if args.training_comment:
metadata["ss_training_comment"] = args.training_comment
metadata["ss_leco_preview"] = json.dumps(
[
{
"target": p.target,
"positive": p.positive,
"unconditional": p.unconditional,
"neutral": p.neutral,
"action": p.action,
"multiplier": p.multiplier,
"weight": p.weight,
}
for p in prompt_settings[:16]
],
ensure_ascii=False,
)
unwrapped = accelerator.unwrap_model(network)
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
logger.info(f"saved model to: {ckpt_file}")
def main():
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train_util.verify_training_args(args)
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
if args.output_dir is None:
raise ValueError("--output_dir is required")
if args.network_train_text_encoder_only:
raise ValueError("LECO does not support text encoder LoRA training")
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)
set_seed(args.seed)
accelerator = train_util.prepare_accelerator(args)
weight_dtype, save_dtype = train_util.prepare_dtype(args)
prompt_settings = load_prompt_settings(args.prompts_file)
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
_, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
del vae
text_encoders = [text_encoder1, text_encoder2]
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.train()
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
prompt_cache = PromptEmbedsCache()
unique_prompts = sorted(
{
prompt
for setting in prompt_settings
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
}
)
with torch.no_grad():
for prompt in unique_prompts:
prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
for text_encoder in text_encoders:
text_encoder.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
network_module = importlib.import_module(args.network_module)
net_kwargs = build_network_kwargs(args)
if args.dim_from_weights:
if args.network_weights is None:
raise ValueError("--dim_from_weights requires --network_weights")
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
else:
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
None,
text_encoders,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
network.set_multiplier(0.0)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
logger.info(f"loaded network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing()
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
optimizer_train_fn()
train_util.init_trackers(accelerator, args, "sdxl_leco_train")
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
while global_step < args.max_train_steps:
with accelerator.accumulate(network):
optimizer.zero_grad(set_to_none=True)
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
height, width = get_random_resolution(setting)
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
accelerator.device, dtype=weight_dtype
)
latents = apply_noise_offset(latents, args.noise_offset)
add_time_ids = get_add_time_ids(
height,
width,
dynamic_crops=setting.dynamic_crops,
dtype=weight_dtype,
device=accelerator.device,
)
batched_time_ids = concat_embeddings(add_time_ids, add_time_ids, setting.batch_size)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
denoised_latents = diffusion_xl(
unet,
noise_scheduler,
latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
add_time_ids=batched_time_ids,
total_timesteps=timesteps_to,
guidance_scale=args.leco_denoise_guidance_scale,
)
noise_scheduler.set_timesteps(1000, device=accelerator.device)
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
network_multiplier.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
positive_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
neutral_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
unconditional_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
target_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() * setting.weight
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(0.0)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"guidance_scale": setting.guidance_scale,
"network_multiplier": setting.multiplier,
}
accelerator.log(logs, step=global_step)
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,114 @@
from pathlib import Path
import torch
from library.leco_train_util import load_prompt_settings
def test_load_prompt_settings_with_original_format(tmp_path: Path):
prompt_file = tmp_path / "prompts.yaml"
prompt_file.write_text(
"""
- target: "van gogh"
guidance_scale: 1.5
resolution: 512
""".strip(),
encoding="utf-8",
)
prompts = load_prompt_settings(prompt_file)
assert len(prompts) == 1
assert prompts[0].target == "van gogh"
assert prompts[0].positive == "van gogh"
assert prompts[0].unconditional == ""
assert prompts[0].neutral == ""
assert prompts[0].action == "erase"
assert prompts[0].guidance_scale == 1.5
def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
prompt_file = tmp_path / "slider.yaml"
prompt_file.write_text(
"""
targets:
- target_class: ""
positive: "high detail"
negative: "low detail"
multiplier: 1.25
weight: 0.5
guidance_scale: 2.0
resolution: 768
neutral: ""
""".strip(),
encoding="utf-8",
)
prompts = load_prompt_settings(prompt_file)
assert len(prompts) == 4
first = prompts[0]
second = prompts[1]
third = prompts[2]
fourth = prompts[3]
assert first.target == ""
assert first.positive == "low detail"
assert first.unconditional == "high detail"
assert first.action == "erase"
assert first.multiplier == 1.25
assert first.weight == 0.5
assert first.get_resolution() == (768, 768)
assert second.positive == "high detail"
assert second.unconditional == "low detail"
assert second.action == "enhance"
assert second.multiplier == 1.25
assert third.action == "erase"
assert third.multiplier == -1.25
assert fourth.action == "enhance"
assert fourth.multiplier == -1.25
def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
from library import sdxl_train_util
from library.leco_train_util import PromptEmbedsXL, predict_noise_xl
class DummyScheduler:
def scale_model_input(self, latent_model_input, timestep):
return latent_model_input
class DummyUNet:
def __call__(self, x, timesteps, context, y):
self.x = x
self.timesteps = timesteps
self.context = context
self.y = y
return torch.zeros_like(x)
latents = torch.randn(1, 4, 8, 8)
prompt_embeds = PromptEmbedsXL(
text_embeds=torch.randn(2, 77, 2048),
pooled_embeds=torch.randn(2, 1280),
)
add_time_ids = torch.tensor(
[
[1024, 1024, 0, 0, 1024, 1024],
[1024, 1024, 0, 0, 1024, 1024],
],
dtype=prompt_embeds.pooled_embeds.dtype,
)
unet = DummyUNet()
noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)
expected_size_embeddings = sdxl_train_util.get_size_embeddings(
add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
).to(prompt_embeds.pooled_embeds.dtype)
assert noise_pred.shape == latents.shape
assert unet.context is prompt_embeds.text_embeds
assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))

View File

@@ -0,0 +1,16 @@
import sdxl_train_leco
from library import deepspeed_utils, sdxl_train_util, train_util
def test_syntax():
assert sdxl_train_leco is not None
def test_setup_parser_supports_shared_training_validation():
args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
train_util.verify_training_args(args)
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
assert args.min_snr_gamma is None
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None

15
tests/test_train_leco.py Normal file
View File

@@ -0,0 +1,15 @@
import train_leco
from library import deepspeed_utils, train_util
def test_syntax():
assert train_leco is not None
def test_setup_parser_supports_shared_training_validation():
args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
train_util.verify_training_args(args)
assert args.min_snr_gamma is None
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None

379
train_leco.py Normal file
View File

@@ -0,0 +1,379 @@
import argparse
import importlib
import json
import os
import random
from typing import Dict
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import custom_train_functions, strategy_sd, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
PromptEmbedsCache,
apply_noise_offset,
concat_embeddings,
diffusion,
encode_prompt_sd,
get_initial_latents,
get_random_resolution,
load_prompt_settings,
predict_noise,
)
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_optimizer_arguments(parser)
train_util.add_training_arguments(parser, support_dreambooth=False)
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
add_logging_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml")
parser.add_argument(
"--max_denoising_steps",
type=int,
default=40,
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
)
parser.add_argument(
"--leco_denoise_guidance_scale",
type=float,
default=3.0,
help="guidance scale for the partial denoising pass / 部分デイズ時のguidance scale",
)
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
)
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
return parser
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
kwargs = {}
if args.network_args:
for net_arg in args.network_args:
key, value = net_arg.split("=", 1)
kwargs[key] = value
if "dropout" not in kwargs:
kwargs["dropout"] = args.network_dropout
return kwargs
def get_save_extension(args: argparse.Namespace) -> str:
if args.save_model_as == "ckpt":
return ".ckpt"
if args.save_model_as == "pt":
return ".pt"
return ".safetensors"
def save_weights(
accelerator,
network,
args: argparse.Namespace,
save_dtype,
prompt_settings,
global_step: int,
last: bool = False,
) -> None:
os.makedirs(args.output_dir, exist_ok=True)
ext = get_save_extension(args)
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata = None
if not args.no_metadata:
metadata = {
"ss_network_module": args.network_module,
"ss_network_dim": str(args.network_dim),
"ss_network_alpha": str(args.network_alpha),
"ss_leco_prompt_count": str(len(prompt_settings)),
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
}
if args.training_comment:
metadata["ss_training_comment"] = args.training_comment
metadata["ss_leco_preview"] = json.dumps(
[
{
"target": p.target,
"positive": p.positive,
"unconditional": p.unconditional,
"neutral": p.neutral,
"action": p.action,
"multiplier": p.multiplier,
"weight": p.weight,
}
for p in prompt_settings[:16]
],
ensure_ascii=False,
)
unwrapped = accelerator.unwrap_model(network)
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
logger.info(f"saved model to: {ckpt_file}")
def main():
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train_util.verify_training_args(args)
if args.output_dir is None:
raise ValueError("--output_dir is required")
if args.network_train_text_encoder_only:
raise ValueError("LECO does not support text encoder LoRA training")
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)
set_seed(args.seed)
accelerator = train_util.prepare_accelerator(args)
weight_dtype, save_dtype = train_util.prepare_dtype(args)
prompt_settings = load_prompt_settings(args.prompts_file)
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
del vae
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.train()
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
prompt_cache = PromptEmbedsCache()
unique_prompts = sorted(
{
prompt
for setting in prompt_settings
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
}
)
with torch.no_grad():
for prompt in unique_prompts:
prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
text_encoder.to("cpu")
clean_memory_on_device(accelerator.device)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
network_module = importlib.import_module(args.network_module)
net_kwargs = build_network_kwargs(args)
if args.dim_from_weights:
if args.network_weights is None:
raise ValueError("--dim_from_weights requires --network_weights")
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
else:
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
None,
text_encoder,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
network.set_multiplier(0.0)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
logger.info(f"loaded network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing()
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
optimizer_train_fn()
train_util.init_trackers(accelerator, args, "leco_train")
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
while global_step < args.max_train_steps:
with accelerator.accumulate(network):
optimizer.zero_grad(set_to_none=True)
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
height, width = get_random_resolution(setting)
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
accelerator.device, dtype=weight_dtype
)
latents = apply_noise_offset(latents, args.noise_offset)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
denoised_latents = diffusion(
unet,
noise_scheduler,
latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
total_timesteps=timesteps_to,
guidance_scale=args.leco_denoise_guidance_scale,
)
noise_scheduler.set_timesteps(1000, device=accelerator.device)
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
network_multiplier.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
positive_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
guidance_scale=1.0,
)
neutral_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
guidance_scale=1.0,
)
unconditional_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
guidance_scale=1.0,
)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
target_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
guidance_scale=1.0,
)
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() * setting.weight
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(0.0)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"guidance_scale": setting.guidance_scale,
"network_multiplier": setting.multiplier,
}
accelerator.log(logs, step=global_step)
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
accelerator.end_training()
if __name__ == "__main__":
main()