Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15 16:39:42 +00:00)
Merge 53e3263253 into b2abe873a5
.gitignore (vendored, 1 line changed)

@@ -11,3 +11,4 @@ GEMINI.md
 .claude
 .gemini
 MagicMock
+.codex-tmp
docs/train_leco.md (new file, 93 lines)
@@ -0,0 +1,93 @@
# LECO Training

This repository now includes dedicated LECO training entry points:

- `train_leco.py` for Stable Diffusion 1.x / 2.x
- `sdxl_train_leco.py` for SDXL

These scripts train a LoRA against the model's own noise predictions, so no image dataset is required.

## Current scope

- U-Net LoRA training only
- `networks.lora` is the default network module
- Prompt YAML supports both original LECO prompt pairs and ai-toolkit style slider targets
- Full ai-toolkit job YAML is not supported; use a prompt/target YAML file only
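The loss target is built from three noise predictions of the frozen model. The standalone sketch below mirrors `PromptSettings.build_target` in `library/leco_train_util.py`; during training, the LoRA-modified prediction for the target prompt is regressed onto this tensor with MSE.

```python
import torch


def build_target(
    positive_latents: torch.Tensor,
    neutral_latents: torch.Tensor,
    unconditional_latents: torch.Tensor,
    guidance_scale: float,
    action: str,
) -> torch.Tensor:
    # The concept direction is what the positive prompt adds over the
    # unconditional prediction; "erase" pushes the edited model away from
    # it, "enhance" pulls toward it.
    offset = guidance_scale * (positive_latents - unconditional_latents)
    if action == "erase":
        return neutral_latents - offset
    return neutral_latents + offset
```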
## Example: SD 1.x / 2.x

```bash
accelerate launch train_leco.py \
  --pretrained_model_name_or_path="model.safetensors" \
  --output_dir="output" \
  --output_name="detail_slider" \
  --prompts_file="prompts.yaml" \
  --network_dim=8 \
  --network_alpha=4 \
  --learning_rate=1e-4 \
  --max_train_steps=500 \
  --max_denoising_steps=40 \
  --mixed_precision=bf16
```
## Example: SDXL

```bash
accelerate launch sdxl_train_leco.py \
  --pretrained_model_name_or_path="sdxl_model.safetensors" \
  --output_dir="output" \
  --output_name="detail_slider_xl" \
  --prompts_file="slider.yaml" \
  --network_dim=8 \
  --network_alpha=4 \
  --learning_rate=1e-4 \
  --max_train_steps=500 \
  --max_denoising_steps=40 \
  --mixed_precision=bf16
```
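Both trainers also accept `--leco_denoise_guidance_scale` (default `3.0`), the classifier-free guidance scale used during the partial denoising pass that produces the latents on which the reference noise predictions are computed; the examples above leave it at the default.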
## Prompt YAML: original LECO format

```yaml
- target: "van gogh"
  positive: "van gogh"
  unconditional: ""
  neutral: ""
  action: "erase"
  guidance_scale: 1.0
  resolution: 512
  batch_size: 1
  multiplier: 1.0
  weight: 1.0
```
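A quick way to sanity-check a prompt file before launching a run is to load it with the same helper the trainers use. A minimal sketch (`load_prompt_settings` and `PromptSettings` are defined in `library/leco_train_util.py`, added by this commit):

```python
from library.leco_train_util import load_prompt_settings

# Raises ValueError for empty or malformed files, so mistakes surface
# before any model weights are loaded.
settings = load_prompt_settings("prompts.yaml")
for s in settings:
    print(s.action, repr(s.target), repr(s.positive), repr(s.unconditional), s.multiplier, s.weight)
```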
## Prompt YAML: ai-toolkit style slider target

This expands internally into the bidirectional LECO pairs needed for slider-style behavior.

```yaml
targets:
  - target_class: ""
    positive: "high detail, intricate, high quality"
    negative: "blurry, low detail, low quality"
    multiplier: 1.0
    weight: 1.0

guidance_scale: 1.0
resolution: 512
neutral: ""
```
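Concretely, a target that sets both `positive` and `negative` expands into four `PromptSettings` per resolution (see `_expand_slider_target` in `library/leco_train_util.py` and the assertions in `tests/library/test_leco_train_util.py`):

- `erase` with the negative prompt as `positive` and the positive prompt as `unconditional`, multiplier `+m`
- `enhance` with the positive prompt as `positive` and the negative prompt as `unconditional`, multiplier `+m`
- `erase` with the positive prompt as `positive` and the negative prompt as `unconditional`, multiplier `-m`
- `enhance` with the negative prompt as `positive` and the positive prompt as `unconditional`, multiplier `-m`

If only one of the two prompts is given, two settings (one per sign of the multiplier) are generated instead.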
You can also provide multiple neutral prompts:

```yaml
targets:
  - target_class: "person"
    positive: "smiling person"
    negative: "expressionless person"

neutrals:
  - ""
  - "studio photo"
  - "cinematic lighting"
```
library/deepspeed_utils.py

@@ -62,7 +62,7 @@ def add_deepspeed_arguments(parser: argparse.ArgumentParser):


 def prepare_deepspeed_args(args: argparse.Namespace):
-    if not args.deepspeed:
+    if not getattr(args, "deepspeed", False):
         return

     # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
@@ -70,7 +70,7 @@ def prepare_deepspeed_args(args: argparse.Namespace):


 def prepare_deepspeed_plugin(args: argparse.Namespace):
-    if not args.deepspeed:
+    if not getattr(args, "deepspeed", False):
         return None

     try:
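The `getattr` guards here (and in the `library/train_util.py` hunks further down) appear intended to let these validators accept argument namespaces from parsers that never registered the corresponding flags, such as the new LECO parsers; the added tests exercise exactly this path by calling `deepspeed_utils.prepare_deepspeed_plugin(args)` on LECO parser output.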
library/leco_train_util.py (new file, 542 lines)
@@ -0,0 +1,542 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import yaml
from torch.utils.checkpoint import checkpoint

from library import sdxl_train_util

ResolutionValue = Union[int, Tuple[int, int]]


@dataclass
class PromptEmbedsXL:
    text_embeds: torch.Tensor
    pooled_embeds: torch.Tensor


class PromptEmbedsCache:
    def __init__(self):
        self.prompts: dict[str, Any] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        self.prompts[name] = value

    def __getitem__(self, name: str) -> Optional[Any]:
        return self.prompts.get(name)


@dataclass
class PromptSettings:
    target: str
    positive: Optional[str] = None
    unconditional: str = ""
    neutral: Optional[str] = None
    action: str = "erase"
    guidance_scale: float = 1.0
    resolution: ResolutionValue = 512
    dynamic_resolution: bool = False
    batch_size: int = 1
    dynamic_crops: bool = False
    multiplier: float = 1.0
    weight: float = 1.0

    def __post_init__(self):
        if self.positive is None:
            self.positive = self.target
        if self.neutral is None:
            self.neutral = self.unconditional
        if self.action not in ("erase", "enhance"):
            raise ValueError(f"Invalid action: {self.action}")

        self.guidance_scale = float(self.guidance_scale)
        self.batch_size = int(self.batch_size)
        self.multiplier = float(self.multiplier)
        self.weight = float(self.weight)
        self.dynamic_resolution = bool(self.dynamic_resolution)
        self.dynamic_crops = bool(self.dynamic_crops)
        self.resolution = normalize_resolution(self.resolution)

    def get_resolution(self) -> Tuple[int, int]:
        if isinstance(self.resolution, tuple):
            return self.resolution
        return (self.resolution, self.resolution)

    def build_target(self, positive_latents, neutral_latents, unconditional_latents):
        offset = self.guidance_scale * (positive_latents - unconditional_latents)
        if self.action == "erase":
            return neutral_latents - offset
        return neutral_latents + offset


def normalize_resolution(value: Any) -> ResolutionValue:
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError(f"resolution tuple must have 2 items: {value}")
        return (int(value[0]), int(value[1]))
    if isinstance(value, list):
        if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
            return (int(value[0]), int(value[1]))
        raise ValueError(f"resolution list must have 2 numeric items: {value}")
    return int(value)


def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f.readlines() if line.strip()]


def _recognized_prompt_keys() -> set[str]:
    return {
        "target",
        "positive",
        "unconditional",
        "neutral",
        "action",
        "guidance_scale",
        "resolution",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _recognized_slider_keys() -> set[str]:
    return {
        "target_class",
        "positive",
        "negative",
        "neutral",
        "guidance_scale",
        "resolution",
        "resolutions",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
    merged = {k: v for k, v in defaults.items() if k in known_keys}
    merged.update(item)
    return merged


def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
    if value is None:
        return [512]
    if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
        return [normalize_resolution(v) for v in value]
    return [normalize_resolution(value)]


def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
    target_class = str(target.get("target_class", ""))
    positive = str(target.get("positive", "") or "")
    negative = str(target.get("negative", "") or "")
    guidance_scale = target.get("guidance_scale", 1.0)
    dynamic_resolution = target.get("dynamic_resolution", False)
    batch_size = target.get("batch_size", 1)
    dynamic_crops = target.get("dynamic_crops", False)
    multiplier = target.get("multiplier", 1.0)
    weight = target.get("weight", 1.0)
    resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))

    if not positive.strip() and not negative.strip():
        raise ValueError("slider target requires either positive or negative prompt")

    prompt_settings: List[PromptSettings] = []

    for resolution in resolutions:
        if positive.strip() and negative.strip():
            prompt_settings.extend(
                [
                    PromptSettings(
                        target=target_class,
                        positive=negative,
                        unconditional=positive,
                        neutral=neutral,
                        action="erase",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=multiplier,
                        weight=weight,
                    ),
                    PromptSettings(
                        target=target_class,
                        positive=positive,
                        unconditional=negative,
                        neutral=neutral,
                        action="enhance",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=multiplier,
                        weight=weight,
                    ),
                    PromptSettings(
                        target=target_class,
                        positive=positive,
                        unconditional=negative,
                        neutral=neutral,
                        action="erase",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=-multiplier,
                        weight=weight,
                    ),
                    PromptSettings(
                        target=target_class,
                        positive=negative,
                        unconditional=positive,
                        neutral=neutral,
                        action="enhance",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=-multiplier,
                        weight=weight,
                    ),
                ]
            )
        elif negative.strip():
            prompt_settings.extend(
                [
                    PromptSettings(
                        target=target_class,
                        positive=negative,
                        unconditional="",
                        neutral=neutral,
                        action="erase",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=multiplier,
                        weight=weight,
                    ),
                    PromptSettings(
                        target=target_class,
                        positive=negative,
                        unconditional="",
                        neutral=neutral,
                        action="enhance",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=-multiplier,
                        weight=weight,
                    ),
                ]
            )
        else:
            prompt_settings.extend(
                [
                    PromptSettings(
                        target=target_class,
                        positive=positive,
                        unconditional="",
                        neutral=neutral,
                        action="enhance",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=multiplier,
                        weight=weight,
                    ),
                    PromptSettings(
                        target=target_class,
                        positive=positive,
                        unconditional="",
                        neutral=neutral,
                        action="erase",
                        guidance_scale=guidance_scale,
                        resolution=resolution,
                        dynamic_resolution=dynamic_resolution,
                        batch_size=batch_size,
                        dynamic_crops=dynamic_crops,
                        multiplier=-multiplier,
                        weight=weight,
                    ),
                ]
            )

    return prompt_settings


def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
    path = Path(path)
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    if data is None:
        raise ValueError("prompt file is empty")

    default_prompt_values = {
        "guidance_scale": 1.0,
        "resolution": 512,
        "dynamic_resolution": False,
        "batch_size": 1,
        "dynamic_crops": False,
        "multiplier": 1.0,
        "weight": 1.0,
    }

    prompt_settings: List[PromptSettings] = []

    def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
        merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
        prompt_settings.append(PromptSettings(**merged))

    def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
        merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
        if not neutral_values:
            neutral_values = [str(merged.get("neutral", "") or "")]
        for neutral in neutral_values:
            prompt_settings.extend(_expand_slider_target(merged, neutral))

    if isinstance(data, list):
        for item in data:
            if "target_class" in item:
                append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
            else:
                append_prompt_item(item, default_prompt_values)
    elif isinstance(data, dict):
        if "prompts" in data:
            defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
            for item in data["prompts"]:
                append_prompt_item(item, defaults)
        else:
            slider_config = data.get("slider", data)
            targets = slider_config.get("targets")
            if targets is None:
                if "target_class" in slider_config:
                    targets = [slider_config]
                elif "target" in slider_config:
                    targets = [slider_config]
                else:
                    raise ValueError("prompt file does not contain prompts or slider targets")
            if len(targets) == 0:
                raise ValueError("prompt file contains an empty targets list")

            if "target" in targets[0]:
                defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
                for item in targets:
                    append_prompt_item(item, defaults)
            else:
                defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
                neutral_values: List[str] = []
                if "neutrals" in slider_config:
                    neutral_values.extend(str(v) for v in slider_config["neutrals"])
                if "neutral_prompt_file" in slider_config:
                    neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
                if "prompt_file" in slider_config:
                    neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
                if not neutral_values:
                    neutral_values = [str(slider_config.get("neutral", "") or "")]

                for item in targets:
                    item_neutrals = neutral_values
                    if "neutrals" in item:
                        item_neutrals = [str(v) for v in item["neutrals"]]
                    elif "neutral_prompt_file" in item:
                        item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
                    elif "prompt_file" in item:
                        item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
                    elif "neutral" in item:
                        item_neutrals = [str(item["neutral"] or "")]

                    append_slider_item(item, defaults, item_neutrals)
    else:
        raise ValueError("prompt file must be a list or mapping")

    if not prompt_settings:
        raise ValueError("no prompt settings found")

    return prompt_settings


def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
    tokens = tokenize_strategy.tokenize(prompt)
    return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]


def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
    tokens = tokenize_strategy.tokenize(prompt)
    hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
    return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)


def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
    if noise_offset is None:
        return latents
    return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)


def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
    noise = torch.randn(
        (batch_size, 4, height // 8, width // 8),
        device="cpu",
    ).repeat(n_prompts, 1, 1, 1)
    return noise * scheduler.init_noise_sigma


def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
    return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)


def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
    text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
        batch_size, dim=0
    )
    return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)


def _run_with_checkpoint(function, *args):
    if torch.is_grad_enabled():
        return checkpoint(function, *args, use_reentrant=False)
    return function(*args)


def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

    def run_unet(model_input, encoder_hidden_states):
        return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample

    noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


def diffusion(
    unet,
    scheduler,
    latents: torch.Tensor,
    text_embeddings: torch.Tensor,
    total_timesteps: int,
    start_timesteps: int = 0,
    guidance_scale: float = 3.0,
):
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents


def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if dynamic_crops:
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        crops_coords_top_left = (
            torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
            torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
    if device is not None:
        add_time_ids = add_time_ids.to(device)
    return add_time_ids


def predict_noise_xl(
    unet,
    scheduler,
    timestep,
    latents: torch.Tensor,
    prompt_embeds: PromptEmbedsXL,
    add_time_ids: torch.Tensor,
    guidance_scale: float = 1.0,
):
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)

    orig_size = add_time_ids[:, :2]
    crop_size = add_time_ids[:, 2:4]
    target_size = add_time_ids[:, 4:6]
    size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
    vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)

    def run_unet(model_input, text_embeds, vector_embeds):
        return unet(model_input, timestep, text_embeds, vector_embeds)

    noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


def diffusion_xl(
    unet,
    scheduler,
    latents: torch.Tensor,
    prompt_embeds: PromptEmbedsXL,
    add_time_ids: torch.Tensor,
    total_timesteps: int,
    start_timesteps: int = 0,
    guidance_scale: float = 3.0,
):
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise_xl(
            unet,
            scheduler,
            timestep,
            latents,
            prompt_embeds=prompt_embeds,
            add_time_ids=add_time_ids,
            guidance_scale=guidance_scale,
        )
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents


def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
    max_resolution = bucket_resolution
    min_resolution = bucket_resolution // 2
    step = 64
    min_step = min_resolution // step
    max_step = max_resolution // step
    height = torch.randint(min_step, max_step + 1, (1,)).item() * step
    width = torch.randint(min_step, max_step + 1, (1,)).item() * step
    return height, width


def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
    height, width = prompt.get_resolution()
    if prompt.dynamic_resolution and height == width:
        return get_random_resolution_in_bucket(height)
    return height, width
library/train_util.py

@@ -4405,7 +4405,7 @@ def verify_command_line_training_args(args: argparse.Namespace):


 def enable_high_vram(args: argparse.Namespace):
-    if args.highvram:
+    if getattr(args, "highvram", False):
         logger.info("highvram is enabled / highvramが有効です")
         global HIGH_VRAM
         HIGH_VRAM = True
@@ -4418,10 +4418,10 @@ def verify_training_args(args: argparse.Namespace):
     """
     enable_high_vram(args)

-    if args.v2 and args.clip_skip is not None:
+    if getattr(args, "v2", False) and getattr(args, "clip_skip", None) is not None:
         logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

-    if args.cache_latents_to_disk and not args.cache_latents:
+    if getattr(args, "cache_latents_to_disk", False) and not getattr(args, "cache_latents", False):
         args.cache_latents = True
         logger.warning(
             "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
@@ -4440,32 +4440,32 @@ def verify_training_args(args: argparse.Namespace):
     #         "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
     #     )

-    if args.adaptive_noise_scale is not None and args.noise_offset is None:
+    if getattr(args, "adaptive_noise_scale", None) is not None and getattr(args, "noise_offset", None) is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

-    if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
+    if getattr(args, "scale_v_pred_loss_like_noise_pred", False) and not getattr(args, "v_parameterization", False):
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )

-    if args.v_pred_like_loss and args.v_parameterization:
+    if getattr(args, "v_pred_like_loss", None) and getattr(args, "v_parameterization", False):
         raise ValueError(
             "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
         )

-    if args.zero_terminal_snr and not args.v_parameterization:
+    if getattr(args, "zero_terminal_snr", False) and not getattr(args, "v_parameterization", False):
         logger.warning(
             f"zero_terminal_snr is enabled, but v_parameterization is not enabled. training will be unexpected"
             + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
         )

-    if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
+    if getattr(args, "sample_every_n_epochs", None) is not None and args.sample_every_n_epochs <= 0:
         logger.warning(
             "sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
         )
         args.sample_every_n_epochs = None

-    if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
+    if getattr(args, "sample_every_n_steps", None) is not None and args.sample_every_n_steps <= 0:
         logger.warning(
             "sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
         )
sdxl_train_leco.py (new file, 402 lines)
@@ -0,0 +1,402 @@
import argparse
import importlib
import json
import os
import random
from typing import Dict

import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
    PromptEmbedsCache,
    apply_noise_offset,
    concat_embeddings,
    concat_embeddings_xl,
    diffusion_xl,
    encode_prompt_sdxl,
    get_add_time_ids,
    get_initial_latents,
    get_random_resolution,
    load_prompt_settings,
    predict_noise_xl,
)
from library.utils import add_logging_arguments, setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    train_util.add_training_arguments(parser, support_dreambooth=False)
    custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
    sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
    add_logging_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")

    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml")
    parser.add_argument(
        "--max_denoising_steps",
        type=int,
        default=40,
        help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
    )
    parser.add_argument(
        "--leco_denoise_guidance_scale",
        type=float,
        default=3.0,
        help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
    )

    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
    parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
    parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
    parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
    parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
    parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="unsupported for LECO; kept for compatibility / LECOでは未対応",
    )
    parser.add_argument(
        "--network_train_unet_only",
        action="store_true",
        help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
    )
    parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
    parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")

    return parser


def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
    kwargs = {}
    if args.network_args:
        for net_arg in args.network_args:
            key, value = net_arg.split("=", 1)
            kwargs[key] = value
    if "dropout" not in kwargs:
        kwargs["dropout"] = args.network_dropout
    return kwargs


def get_save_extension(args: argparse.Namespace) -> str:
    if args.save_model_as == "ckpt":
        return ".ckpt"
    if args.save_model_as == "pt":
        return ".pt"
    return ".safetensors"


def save_weights(
    accelerator,
    network,
    args: argparse.Namespace,
    save_dtype,
    prompt_settings,
    global_step: int,
    last: bool = False,
) -> None:
    os.makedirs(args.output_dir, exist_ok=True)
    ext = get_save_extension(args)
    ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    metadata = None
    if not args.no_metadata:
        metadata = {
            "ss_network_module": args.network_module,
            "ss_network_dim": str(args.network_dim),
            "ss_network_alpha": str(args.network_alpha),
            "ss_leco_prompt_count": str(len(prompt_settings)),
            "ss_leco_prompts_file": os.path.basename(args.prompts_file),
            "ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0,
        }
        if args.training_comment:
            metadata["ss_training_comment"] = args.training_comment
        metadata["ss_leco_preview"] = json.dumps(
            [
                {
                    "target": p.target,
                    "positive": p.positive,
                    "unconditional": p.unconditional,
                    "neutral": p.neutral,
                    "action": p.action,
                    "multiplier": p.multiplier,
                    "weight": p.weight,
                }
                for p in prompt_settings[:16]
            ],
            ensure_ascii=False,
        )

    unwrapped = accelerator.unwrap_model(network)
    unwrapped.save_weights(ckpt_file, save_dtype, metadata)
    logger.info(f"saved model to: {ckpt_file}")


def main():
    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
    train_util.verify_training_args(args)
    sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

    if args.output_dir is None:
        raise ValueError("--output_dir is required")
    if args.network_train_text_encoder_only:
        raise ValueError("LECO does not support text encoder LoRA training")

    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
    set_seed(args.seed)

    accelerator = train_util.prepare_accelerator(args)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    prompt_settings = load_prompt_settings(args.prompts_file)
    logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")

    _, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
        args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
    )
    del vae
    text_encoders = [text_encoder1, text_encoder2]

    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    unet.train()

    tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
    text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()

    for text_encoder in text_encoders:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    prompt_cache = PromptEmbedsCache()
    unique_prompts = sorted(
        {
            prompt
            for setting in prompt_settings
            for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
        }
    )
    with torch.no_grad():
        for prompt in unique_prompts:
            prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)

    for text_encoder in text_encoders:
        text_encoder.to("cpu", dtype=torch.float32)
    clean_memory_on_device(accelerator.device)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    if args.zero_terminal_snr:
        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    network_module = importlib.import_module(args.network_module)
    net_kwargs = build_network_kwargs(args)
    if args.dim_from_weights:
        if args.network_weights is None:
            raise ValueError("--dim_from_weights requires --network_weights")
        network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
    else:
        network = network_module.create_network(
            1.0,
            args.network_dim,
            args.network_alpha,
            None,
            text_encoders,
            unet,
            neuron_dropout=args.network_dropout,
            **net_kwargs,
        )

    network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
    network.set_multiplier(0.0)

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        logger.info(f"loaded network weights from {args.network_weights}: {info}")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        network.enable_gradient_checkpointing()

    unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
    trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
    accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
    optimizer_train_fn()
    train_util.init_trackers(accelerator, args, "sdxl_leco_train")

    progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    while global_step < args.max_train_steps:
        with accelerator.accumulate(network):
            optimizer.zero_grad(set_to_none=True)

            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)

            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
            height, width = get_random_resolution(setting)

            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
                accelerator.device, dtype=weight_dtype
            )
            latents = apply_noise_offset(latents, args.noise_offset)
            add_time_ids = get_add_time_ids(
                height,
                width,
                dynamic_crops=setting.dynamic_crops,
                dtype=weight_dtype,
                device=accelerator.device,
            )
            batched_time_ids = concat_embeddings(add_time_ids, add_time_ids, setting.batch_size)

            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                denoised_latents = diffusion_xl(
                    unet,
                    noise_scheduler,
                    latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    total_timesteps=timesteps_to,
                    guidance_scale=args.leco_denoise_guidance_scale,
                )

            noise_scheduler.set_timesteps(1000, device=accelerator.device)
            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
            current_timestep = noise_scheduler.timesteps[current_timestep_index]

            network_multiplier.set_multiplier(0.0)
            with torch.no_grad(), accelerator.autocast():
                positive_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )
                neutral_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )
                unconditional_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )

            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                target_latents = predict_noise_xl(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    add_time_ids=batched_time_ids,
                    guidance_scale=1.0,
                )

            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
            loss = loss.mean(dim=(1, 2, 3))
            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
                timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
            loss = loss.mean() * setting.weight

            accelerator.backward(loss)

            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(0.0)

            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "guidance_scale": setting.guidance_scale,
                "network_multiplier": setting.multiplier,
            }
            accelerator.log(logs, step=global_step)
            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")

            if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()
tests/library/test_leco_train_util.py (new file, 114 lines)
@@ -0,0 +1,114 @@
from pathlib import Path

import torch

from library.leco_train_util import load_prompt_settings


def test_load_prompt_settings_with_original_format(tmp_path: Path):
    prompt_file = tmp_path / "prompts.yaml"
    prompt_file.write_text(
        """
- target: "van gogh"
  guidance_scale: 1.5
  resolution: 512
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 1
    assert prompts[0].target == "van gogh"
    assert prompts[0].positive == "van gogh"
    assert prompts[0].unconditional == ""
    assert prompts[0].neutral == ""
    assert prompts[0].action == "erase"
    assert prompts[0].guidance_scale == 1.5


def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
    prompt_file = tmp_path / "slider.yaml"
    prompt_file.write_text(
        """
targets:
  - target_class: ""
    positive: "high detail"
    negative: "low detail"
    multiplier: 1.25
    weight: 0.5
guidance_scale: 2.0
resolution: 768
neutral: ""
""".strip(),
        encoding="utf-8",
    )

    prompts = load_prompt_settings(prompt_file)

    assert len(prompts) == 4

    first = prompts[0]
    second = prompts[1]
    third = prompts[2]
    fourth = prompts[3]

    assert first.target == ""
    assert first.positive == "low detail"
    assert first.unconditional == "high detail"
    assert first.action == "erase"
    assert first.multiplier == 1.25
    assert first.weight == 0.5
    assert first.get_resolution() == (768, 768)

    assert second.positive == "high detail"
    assert second.unconditional == "low detail"
    assert second.action == "enhance"
    assert second.multiplier == 1.25

    assert third.action == "erase"
    assert third.multiplier == -1.25

    assert fourth.action == "enhance"
    assert fourth.multiplier == -1.25


def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
    from library import sdxl_train_util
    from library.leco_train_util import PromptEmbedsXL, predict_noise_xl

    class DummyScheduler:
        def scale_model_input(self, latent_model_input, timestep):
            return latent_model_input

    class DummyUNet:
        def __call__(self, x, timesteps, context, y):
            self.x = x
            self.timesteps = timesteps
            self.context = context
            self.y = y
            return torch.zeros_like(x)

    latents = torch.randn(1, 4, 8, 8)
    prompt_embeds = PromptEmbedsXL(
        text_embeds=torch.randn(2, 77, 2048),
        pooled_embeds=torch.randn(2, 1280),
    )
    add_time_ids = torch.tensor(
        [
            [1024, 1024, 0, 0, 1024, 1024],
            [1024, 1024, 0, 0, 1024, 1024],
        ],
        dtype=prompt_embeds.pooled_embeds.dtype,
    )

    unet = DummyUNet()
    noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)

    expected_size_embeddings = sdxl_train_util.get_size_embeddings(
        add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
    ).to(prompt_embeds.pooled_embeds.dtype)

    assert noise_pred.shape == latents.shape
    assert unet.context is prompt_embeds.text_embeds
    assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
tests/test_sdxl_train_leco.py (new file, 16 lines)
@@ -0,0 +1,16 @@
import sdxl_train_leco
from library import deepspeed_utils, sdxl_train_util, train_util


def test_syntax():
    assert sdxl_train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)
    sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
tests/test_train_leco.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import train_leco
from library import deepspeed_utils, train_util


def test_syntax():
    assert train_leco is not None


def test_setup_parser_supports_shared_training_validation():
    args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])

    train_util.verify_training_args(args)

    assert args.min_snr_gamma is None
    assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
train_leco.py (new file, 379 lines)
@@ -0,0 +1,379 @@
import argparse
import importlib
import json
import os
import random
from typing import Dict

import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import custom_train_functions, strategy_sd, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
    PromptEmbedsCache,
    apply_noise_offset,
    concat_embeddings,
    diffusion,
    encode_prompt_sd,
    get_initial_latents,
    get_random_resolution,
    load_prompt_settings,
    predict_noise,
)
from library.utils import add_logging_arguments, setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    train_util.add_sd_models_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    train_util.add_training_arguments(parser, support_dreambooth=False)
    custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
    add_logging_arguments(parser)

    parser.add_argument(
        "--save_model_as",
        type=str,
        default="safetensors",
        choices=[None, "ckpt", "pt", "safetensors"],
        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")

    parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt yaml / LECO用のprompt yaml")
    parser.add_argument(
        "--max_denoising_steps",
        type=int,
        default=40,
        help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
    )
    parser.add_argument(
        "--leco_denoise_guidance_scale",
        type=float,
        default=3.0,
        help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
    )

    parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
    parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
    parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
    parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
    parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
    parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
    parser.add_argument(
        "--network_train_text_encoder_only",
        action="store_true",
        help="unsupported for LECO; kept for compatibility / LECOでは未対応",
    )
    parser.add_argument(
        "--network_train_unet_only",
        action="store_true",
        help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
    )
    parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
    parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")

    return parser


def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
    kwargs = {}
    if args.network_args:
        for net_arg in args.network_args:
            key, value = net_arg.split("=", 1)
            kwargs[key] = value
    if "dropout" not in kwargs:
        kwargs["dropout"] = args.network_dropout
    return kwargs


def get_save_extension(args: argparse.Namespace) -> str:
    if args.save_model_as == "ckpt":
        return ".ckpt"
    if args.save_model_as == "pt":
        return ".pt"
    return ".safetensors"


def save_weights(
    accelerator,
    network,
    args: argparse.Namespace,
    save_dtype,
    prompt_settings,
    global_step: int,
    last: bool = False,
) -> None:
    os.makedirs(args.output_dir, exist_ok=True)
    ext = get_save_extension(args)
    ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
    ckpt_file = os.path.join(args.output_dir, ckpt_name)

    metadata = None
    if not args.no_metadata:
        metadata = {
            "ss_network_module": args.network_module,
            "ss_network_dim": str(args.network_dim),
            "ss_network_alpha": str(args.network_alpha),
            "ss_leco_prompt_count": str(len(prompt_settings)),
            "ss_leco_prompts_file": os.path.basename(args.prompts_file),
        }
        if args.training_comment:
            metadata["ss_training_comment"] = args.training_comment
        metadata["ss_leco_preview"] = json.dumps(
            [
                {
                    "target": p.target,
                    "positive": p.positive,
                    "unconditional": p.unconditional,
                    "neutral": p.neutral,
                    "action": p.action,
                    "multiplier": p.multiplier,
                    "weight": p.weight,
                }
                for p in prompt_settings[:16]
            ],
            ensure_ascii=False,
        )

    unwrapped = accelerator.unwrap_model(network)
    unwrapped.save_weights(ckpt_file, save_dtype, metadata)
    logger.info(f"saved model to: {ckpt_file}")


def main():
    parser = setup_parser()
    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)
    train_util.verify_training_args(args)

    if args.output_dir is None:
        raise ValueError("--output_dir is required")
    if args.network_train_text_encoder_only:
        raise ValueError("LECO does not support text encoder LoRA training")

    if args.seed is None:
        args.seed = random.randint(0, 2**32 - 1)
    set_seed(args.seed)

    accelerator = train_util.prepare_accelerator(args)
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    prompt_settings = load_prompt_settings(args.prompts_file)
    logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")

    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
    del vae

    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    unet.train()

    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
    text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    prompt_cache = PromptEmbedsCache()
    unique_prompts = sorted(
        {
            prompt
            for setting in prompt_settings
            for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
        }
    )
    with torch.no_grad():
        for prompt in unique_prompts:
            prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)

    text_encoder.to("cpu")
    clean_memory_on_device(accelerator.device)

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
    if args.zero_terminal_snr:
        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)

    network_module = importlib.import_module(args.network_module)
    net_kwargs = build_network_kwargs(args)
    if args.dim_from_weights:
        if args.network_weights is None:
            raise ValueError("--dim_from_weights requires --network_weights")
        network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
    else:
        network = network_module.create_network(
            1.0,
            args.network_dim,
            args.network_alpha,
            None,
            text_encoder,
            unet,
            neuron_dropout=args.network_dropout,
            **net_kwargs,
        )

    network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
    network.set_multiplier(0.0)

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        logger.info(f"loaded network weights from {args.network_weights}: {info}")

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        network.enable_gradient_checkpointing()

    unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
    trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
    _, _, optimizer = train_util.get_optimizer(args, trainable_params)
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
    accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
    optimizer_train_fn()
    train_util.init_trackers(accelerator, args, "leco_train")

    progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    while global_step < args.max_train_steps:
        with accelerator.accumulate(network):
            optimizer.zero_grad(set_to_none=True)

            setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
            noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)

            timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
            height, width = get_random_resolution(setting)

            latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
                accelerator.device, dtype=weight_dtype
            )
            latents = apply_noise_offset(latents, args.noise_offset)

            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                denoised_latents = diffusion(
                    unet,
                    noise_scheduler,
                    latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    total_timesteps=timesteps_to,
                    guidance_scale=args.leco_denoise_guidance_scale,
                )

            noise_scheduler.set_timesteps(1000, device=accelerator.device)
            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
            current_timestep = noise_scheduler.timesteps[current_timestep_index]

            network_multiplier.set_multiplier(0.0)
            with torch.no_grad(), accelerator.autocast():
                positive_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
                    guidance_scale=1.0,
                )
                neutral_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
                    guidance_scale=1.0,
                )
                unconditional_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
                    guidance_scale=1.0,
                )

            network_multiplier.set_multiplier(setting.multiplier)
            with accelerator.autocast():
                target_latents = predict_noise(
                    unet,
                    noise_scheduler,
                    current_timestep,
                    denoised_latents,
                    concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
                    guidance_scale=1.0,
                )

            target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
            loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
            loss = loss.mean(dim=(1, 2, 3))
            if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
                timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
            loss = loss.mean() * setting.weight

            accelerator.backward(loss)

            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
            network_multiplier = accelerator.unwrap_model(network)
            network_multiplier.set_multiplier(0.0)

            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "guidance_scale": setting.guidance_scale,
                "network_multiplier": setting.multiplier,
            }
            accelerator.log(logs, step=global_step)
            progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")

            if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)

    accelerator.end_training()


if __name__ == "__main__":
    main()