Compare commits

...

26 Commits

Author SHA1 Message Date
Qing Long
77e08c5400 Merge 201e1997a2 into 1dae34b0af 2026-03-30 22:41:30 +00:00
Kohya S.
1dae34b0af Merge pull request #2298 from kohya-ss/sd3
Merge development changes into main
2026-03-30 07:53:22 +09:00
Kohya S.
dd7a666727 Merge pull request #2301 from woct0rdho/resize-lora-logging
Print verbose info while resizing LoRA is running
2026-03-30 07:50:15 +09:00
woctordho
b2c330407b Print verbose info while extracting 2026-03-29 21:36:33 +08:00
Kohya S.
c018765583 Merge pull request #2300 from kohya-ss/doc/re-update-README
doc: update change history in README files to include LECO training
2026-03-29 22:10:53 +09:00
Kohya S
3cb9025b4b doc: update change history in README files to include LECO training support for SD/SDXL 2026-03-29 22:07:52 +09:00
Kohya S.
adf4b7b9c0 Merge pull request #2299 from kohya-ss/docs/fix-missing-docs
Improve clarity of README table of contents and change history
2026-03-29 22:02:43 +09:00
Kohya S
b637c31365 fix: update table of contents and change history in README files for clarity 2026-03-29 21:58:38 +09:00
Kohya S.
7cbae516c1 Merge pull request #2297 from kohya-ss/fix-anima-fp16-nan-issue
fix: AdaLN modulation to use float32 for numerical stability in fp16
2026-03-29 21:31:19 +09:00
Kohya S
5fb3172baf fix: AdaLN modulation to use float32 for numerical stability in fp16 2026-03-29 21:25:53 +09:00
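The general pattern behind this fix, sketched below as a minimal illustration (function and tensor names are hypothetical, not the actual Anima code), is to perform the AdaLN shift/scale modulation in float32 and cast the result back to the working dtype so fp16 activations do not overflow to NaN:

```python
import torch

def adaln_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch: do the modulation arithmetic in float32 for
    # numerical stability, then return the result in the original dtype
    # (e.g. fp16) expected by the rest of the network.
    orig_dtype = x.dtype
    out = x.float() * (1.0 + scale.float()) + shift.float()
    return out.to(orig_dtype)
```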
Kohya S.
5cdad10de5 Fix/leco cleanup (#2294)
* feat: add LECO training scripts for SD1.x/2.x and SDXL (#2285)

* Add LECO training script and associated tests

- Implemented `sdxl_train_leco.py` for training with LECO prompts, including argument parsing, model setup, training loop, and weight saving functionality.
- Created unit tests for `load_prompt_settings` in `test_leco_train_util.py` to validate loading of prompt configurations in both original and slider formats.
- Added basic syntax tests for `train_leco.py` and `sdxl_train_leco.py` to ensure modules are importable.

* fix: use getattr for safe attribute access in argument verification

* feat: add CUDA device compatibility validation and corresponding tests

* Revert "feat: add CUDA device compatibility validation and corresponding tests"

This reverts commit 6d3e51431b.

* feat: update predict_noise_xl to use vector embedding from add_time_ids

* feat: implement checkpointing in predict_noise and predict_noise_xl functions

* feat: remove unused submodules and update .gitignore to exclude .codex-tmp

---------

Co-authored-by: Kohya S. <52813779+kohya-ss@users.noreply.github.com>

* fix: format

* fix: address review feedback on LECO PR #2285

- Revert the getattr changes in train_util.py/deepspeed_utils.py and add dummy arguments to the LECO parser
- Change the module-level import of sdxl_train_util to a local import
- Make PromptEmbedsCache.__getitem__ raise KeyError on a cache miss
- Change the config file format from YAML to TOML (to match the repository convention)
- Consolidate duplicated code (build_network_kwargs, get_save_extension, save_weights) into leco_train_util.py
- Simplify the redundant PromptSettings construction in _expand_slider_target
- Add a dedicated batch_add_time_ids function for add_time_ids

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: greatly expand the LECO training guide

Add per-category explanations of all command-line arguments, descriptions of every prompt TOML field,
the difference between the two guidance_scale parameters, a recommended-settings table, a guide for converting from YAML, and more.
Bilingual layout with an English body and collapsible Japanese sections.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: resolve dtype mismatch in apply_noise_offset

Fix latents being implicitly upcast because torch.randn defaults to float32.
Adopt the safe pattern of generating the noise in float32 on CPU and then converting it to the latents' dtype/device.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
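A minimal sketch of the pattern described above (illustrative only; the function name and shapes are assumptions, not the repository's apply_noise_offset implementation):

```python
import torch

def add_noise_offset(noise: torch.Tensor, latents: torch.Tensor, noise_offset: float) -> torch.Tensor:
    # torch.randn defaults to float32 on CPU; casting the offset to the
    # latents' dtype/device avoids implicitly upcasting fp16/bf16 tensors.
    offset = torch.randn(latents.shape[0], latents.shape[1], 1, 1)
    offset = offset.to(dtype=latents.dtype, device=latents.device)
    return noise + noise_offset * offset
```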

---------

Co-authored-by: Umisetokikaze <52318966+umisetokikaze@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-29 20:41:43 +09:00
Kohya S.
89b246f3f6 Merge pull request #2296 from kohya-ss/add-svd-lowrank-niter
Add --svd_lowrank_niter option to resize_lora.py
2026-03-29 20:40:55 +09:00
woctordho
4be0e94fad Merge pull request #2194 from woct0rdho/rank1
Fix the 'off by 1' problem in dynamically resized LoRA rank
2026-03-29 20:35:00 +09:00
Kohya S
0e168dd1eb add --svd_lowrank_niter option to resize_lora.py
Allow users to control the number of iterations for torch.svd_lowrank
on large matrices. Default is 2 (matching PR #2240 behavior). Set to 0
to disable svd_lowrank and use full SVD instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-29 20:33:33 +09:00
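A representative invocation might look like this; apart from `--svd_lowrank_niter`, the flags shown are assumed from the existing `resize_lora.py` options and should be checked against `--help`:

```bash
python networks/resize_lora.py \
  --model input_lora.safetensors \
  --save_to resized_lora_r16.safetensors \
  --new_rank 16 \
  --device cuda \
  --svd_lowrank_niter 2    # 0 disables svd_lowrank and falls back to full SVD
```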
Kohya S.
2723a75f91 Merge pull request #2240 from woct0rdho/svd-lowrank
Use `torch.svd_lowrank` for large matrices in `resize_lora.py`
2026-03-29 19:54:44 +09:00
Kohya S.
5f793fb0f4 Log d*lr for ProdigyPlusScheduleFree (#2289) 2026-03-29 18:47:09 +09:00
Kohya S.
feb38356ea Merge pull request #2291 from kohya-ss/fix-anima-validation-with-text-encoder-output-cache
fix: Anima validation dataset not working with Text Encoder output cache
2026-03-22 22:23:27 +09:00
Kohya S
cdb49f9fe7 fix: Anima validation dataset not working with Text Encoder output caching due to caption dropout 2026-03-22 22:19:47 +09:00
Kohya S
bd19e4c15d Merge branch 'main' into sd3 2026-03-22 21:10:51 +09:00
woctordho
343c929e39 Log d*lr for ProdigyPlusScheduleFree 2026-03-21 11:09:56 +08:00
Kohya S.
b2abe873a5 Merge pull request #2283 from kozistr/deps/pytorch-optimizer
Bump `pytorch-optimizer` into 3.10.0
2026-03-19 09:18:06 +09:00
Kohya S.
7c159291e9 docs: add skip_image_resolution to config README (#2288)
* docs: add skip_image_resolution option to config README

Document the skip_image_resolution dataset option added in PR #2273.
Add option description, multi-resolution dataset TOML example, and
command-line argument entry to both Japanese and English config READMEs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* docs: clarify `skip_image_resolution` functionality in dataset config

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 09:17:29 +09:00
woctordho
1cd95b2d8b Add skip_image_resolution to deduplicate multi-resolution dataset (#2273)
* Add min_orig_resolution and max_orig_resolution

* Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution

* Change skip_image_resolution to tuple

* Move filtering to __init__

* Minor fix
2026-03-19 08:43:39 +09:00
kozistr
1bd0b0faf1 build(deps): bump pytorch-optimizer into 3.10.0 2026-03-02 14:39:48 +09:00
woctordho
872124c5e1 Use svd_lowrank for large matrices in resize_lora.py 2025-11-17 10:14:17 +08:00
sdbds
201e1997a2 init 2025-05-26 10:03:30 +08:00
21 changed files with 3361 additions and 120 deletions

3
.gitignore vendored

@@ -11,4 +11,5 @@ GEMINI.md
.claude
.gemini
MagicMock
references
.codex-tmp
references

README-ja.md

@@ -8,25 +8,25 @@
<summary>クリックすると展開します</summary>
- [はじめに](#はじめに)
- [スポンサー](#スポンサー)
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
- [更新履歴](#更新履歴)
- [サポートモデル](#サポートモデル)
- [機能](#機能)
- [スポンサー](#スポンサー)
- [スポンサー募集のお知らせ](#スポンサー募集のお知らせ)
- [更新履歴](#更新履歴)
- [サポートモデル](#サポートモデル)
- [機能](#機能)
- [ドキュメント](#ドキュメント)
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
- [その他のドキュメント](#その他のドキュメント)
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
- [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語)
- [その他のドキュメント](#その他のドキュメント)
- [旧ドキュメント(日本語)](#旧ドキュメント日本語)
- [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ)
- [Windows環境でのインストール](#windows環境でのインストール)
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
- [インストール手順](#インストール手順)
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
- [xformersのインストールオプション](#xformersのインストールオプション)
- [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム)
- [インストール手順](#インストール手順)
- [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて)
- [xformersのインストールオプション](#xformersのインストールオプション)
- [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール)
- [DeepSpeedのインストール実験的、LinuxまたはWSL2のみ](#deepspeedのインストール実験的linuxまたはwsl2のみ)
- [DeepSpeedのインストール実験的、LinuxまたはWSL2のみ](#deepspeedのインストール実験的linuxまたはwsl2のみ)
- [アップグレード](#アップグレード)
- [PyTorchのアップグレード](#pytorchのアップグレード)
- [PyTorchのアップグレード](#pytorchのアップグレード)
- [謝意](#謝意)
- [ライセンス](#ライセンス)
@@ -50,15 +50,28 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
- `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。
- デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。
- LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275)
- 詳細は[ドキュメント](./docs/loha_lokr.md)をご覧ください。
- マルチ解像度データセット(同じ画像を複数のbucketサイズにリサイズして使用)がSD/SDXLの学習でサポートされました。[PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) また、マルチ解像度データセットで同じ解像度の画像が重複して使用される事象への対応を行いました。[PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273)
- woct0rdho氏に感謝します。
- [ドキュメント英語版](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [ドキュメント日本語版](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) をご覧ください。
- Animaでfp16で学習する際の安定性が向上しました。[PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) ただし、依然として不安定な場合があるようです。問題が発生する場合は、詳細をIssueでお知らせください。
- その他、細かいバグ修正や改善を行いました。
- **Version 0.10.1 (2026-02-13):**
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
- **Version 0.10.0 (2026-01-19):**
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。
- `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。
### サポートモデル

README.md

@@ -7,23 +7,23 @@
<summary>Click to expand</summary>
- [Introduction](#introduction)
- [Supported Models](#supported-models)
- [Features](#features)
- [Sponsors](#sponsors)
- [Support the Project](#support-the-project)
- [Supported Models](#supported-models)
- [Features](#features)
- [Sponsors](#sponsors)
- [Support the Project](#support-the-project)
- [Documentation](#documentation)
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
- [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese)
- [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese)
- [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents)
- [Windows Installation](#windows-installation)
- [Windows Required Dependencies](#windows-required-dependencies)
- [Installation Steps](#installation-steps)
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
- [xformers installation (optional)](#xformers-installation-optional)
- [Windows Required Dependencies](#windows-required-dependencies)
- [Installation Steps](#installation-steps)
- [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch)
- [xformers installation (optional)](#xformers-installation-optional)
- [Linux/WSL2 Installation](#linuxwsl2-installation)
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
- [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only)
- [Upgrade](#upgrade)
- [Upgrade PyTorch](#upgrade-pytorch)
- [Upgrade PyTorch](#upgrade-pytorch)
- [Credits](#credits)
- [License](#license)
@@ -47,6 +47,19 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.
- `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296).
- It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details.
- LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details.
- Please refer to the [documentation](./docs/loha_lokr.md) for details.
- Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details.
- Thanks to woct0rdho for the contribution.
- Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details.
- Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues.
- Other minor bug fixes and improvements were made.
- **Version 0.10.1 (2026-02-13):**
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
- Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260.

anima_train_network.py

@@ -286,7 +286,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
# Unpack text encoder conditions
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[
:4
] # ignore caption_dropout_rate which is not needed for training step
# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
@@ -353,7 +355,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
# Add the caption dropout rates back to the list for validation dataset (which is re-used batch items)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates]
return super().process_batch(
batch,

617
cogview4_train_network.py Normal file

@@ -0,0 +1,617 @@
import argparse
import copy
import math
import random
import os
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.cogview4.pipeline_cogview4 import CogView4Pipeline
from PIL import Image
from transformers import AutoTokenizer, GlmModel
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_models,
flux_train_utils,
flux_utils,
sd3_train_utils,
strategy_base,
strategy_cogview4,
train_util,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class CogView4NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# CogView4 specific argument validation
if hasattr(args, 'fp8_base_unet') and args.fp8_base_unet:
logger.warning("FP8 training is not yet fully supported for CogView4. Disabling fp8_base_unet.")
args.fp8_base_unet = False
if hasattr(args, 'cache_text_encoder_outputs') and args.cache_text_encoder_outputs:
logger.warning("Text encoder output caching is not yet implemented for CogView4. Disabling.")
args.cache_text_encoder_outputs = False
if hasattr(args, 'cache_text_encoder_outputs_to_disk') and args.cache_text_encoder_outputs_to_disk:
logger.warning("Text encoder output disk caching is not yet implemented for CogView4. Disabling.")
args.cache_text_encoder_outputs_to_disk = False
# Set default values for CogView4
if not hasattr(args, 'max_token_length'):
args.max_token_length = 128 # Default token length for GLM
if not hasattr(args, 'resolution'):
args.resolution = 256 # Default resolution for CogView4
# Update dataset resolution if needed
train_dataset_group.set_resolution(args.resolution)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
"""
Load the CogView4 model components including tokenizer, text encoder, VAE, and transformer.
"""
logger.info(f"Loading CogView4 model from {args.pretrained_model_name_or_path}")
# Load tokenizer and text encoder (GLM)
self.tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
use_fast=False,
trust_remote_code=True
)
self.text_encoder = GlmModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
torch_dtype=weight_dtype,
trust_remote_code=True
)
self.text_encoder.eval()
# Load VAE
self.vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=weight_dtype,
trust_remote_code=True
)
self.vae.eval()
# Load transformer
self.transformer = CogView4Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=weight_dtype,
trust_remote_code=True
)
# Create noise scheduler
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
prediction_type="epsilon",
)
# Move models to device
device = accelerator.device
self.text_encoder = self.text_encoder.to(device)
self.vae = self.vae.to(device)
self.transformer = self.transformer.to(device)
# Set gradient checkpointing if enabled
if args.gradient_checkpointing:
self.transformer.enable_gradient_checkpointing()
# Text encoder gradient checkpointing is handled by prepare_text_encoder_grad_ckpt_workaround
# called by the base trainer if args.train_text_encoder is true for that TE.
# Store components for later use
self.weight_dtype = weight_dtype
# Return components in the expected format
return "cogview4-v1", [self.text_encoder], self.vae, self.transformer
def get_tokenize_strategy(self, args):
# For CogView4, we use a fixed token length for GLM
max_token_length = getattr(args, 'max_token_length', 128)
logger.info(f"Using max_token_length: {max_token_length} for GLM tokenizer")
return strategy_cogview4.CogView4TokenizeStrategy(max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy):
# For CogView4, we only have one tokenizer (GLM)
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_cogview4.CogView4LatentsCachingStrategy(
args.cache_latents_to_disk,
args.vae_batch_size,
skip_disk_cache_validity_check=False
)
def get_text_encoding_strategy(self, args):
# For CogView4, we use GLM instead of T5, but maintain similar interface
return strategy_cogview4.CogView4TextEncodingStrategy(
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
# For CogView4, we always return the text encoder (GLM) as it's needed for encoding
return text_encoders
def get_text_encoders_train_flags(self, args, text_encoders):
# For CogView4, we only have one text encoder (GLM)
return [getattr(args, 'train_text_encoder', False)]
def get_text_encoder_outputs_caching_strategy(self, args):
if getattr(args, 'cache_text_encoder_outputs', False):
return strategy_cogview4.CogView4TextEncoderOutputsCachingStrategy(
cache_to_disk=getattr(args, 'cache_text_encoder_outputs_to_disk', False),
batch_size=getattr(args, 'text_encoder_batch_size', 1),
skip_disk_cache_validity_check=getattr(args, 'skip_cache_check', False),
is_partial=getattr(args, 'train_text_encoder', False)
)
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
"""Cache text encoder outputs to speed up training.
Args:
args: Training arguments
accelerator: Accelerator instance
unet: UNet model
vae: VAE model
text_encoders: List containing the GLM text encoder
dataset: Dataset to cache text encoder outputs for
weight_dtype: Data type for weights
"""
if getattr(args, 'cache_text_encoder_outputs', False):
if not getattr(args, 'lowram', False):
# Free up GPU memory by moving models to CPU
logger.info("Moving VAE and UNet to CPU to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# Move text encoder to GPU with proper dtype
logger.info("Moving text encoder to GPU")
text_encoder = text_encoders[0] # CogView4 uses a single text encoder (GLM)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Cache text encoder outputs
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# Cache sample prompts if provided
if getattr(args, 'sample_prompts', None) is not None:
logger.info(f"Caching text encoder outputs for sample prompts: {args.sample_prompts}")
# Initialize CogView4 strategies
tokenize_strategy = strategy_cogview4.CogView4TokenizeStrategy(
max_length=getattr(args, 'max_token_length', 128),
tokenizer_cache_dir=getattr(args, 'tokenizer_cache_dir', None)
)
text_encoding_strategy = strategy_cogview4.CogView4TextEncodingStrategy(
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p and p not in sample_prompts_te_outputs: # Skip empty prompts and duplicates
logger.info(f"Caching text encoder outputs for prompt: {p}")
tokens = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy=tokenize_strategy,
models=text_encoders,
tokens=tokens,
apply_attention_mask=getattr(args, 'apply_attention_mask', True)
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# Move text encoder back to CPU if not training it
if not getattr(args, 'train_text_encoder', False):
logger.info("Moving text encoder back to CPU")
text_encoder.to("cpu")
clean_memory_on_device(accelerator.device)
# Move VAE and UNet back to their original devices if not in lowram mode
if not getattr(args, 'lowram', False):
logger.info("Moving VAE and UNet back to original devices")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Keep text encoder in GPU if we're not caching outputs
if text_encoders:
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, transformer):
"""
Generate sample images during training to monitor progress.
"""
logger.info(f"Generating sample images at step {global_step}")
# Set models to eval mode
was_training = transformer.training
transformer.eval()
vae.eval()
# Sample prompts to use for generation
sample_prompts = [
"A high quality photo of a cat",
"A beautiful landscape with mountains and a lake",
"A futuristic city at night"
]
# Generate images for each prompt
all_images = []
with torch.no_grad():
for prompt in sample_prompts:
# Tokenize the prompt
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_input.input_ids.to(device)
attention_mask = text_input.attention_mask.to(device)
# Get text embeddings
with torch.no_grad():
text_embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state
# Sample random noise
latents = torch.randn(
(1, 4, args.resolution // 8, args.resolution // 8),
device=device,
dtype=torch.float32
)
# Set the scheduler for inference
self.noise_scheduler.set_timesteps(50, device=device)
# Generate image using the denoising process
for t in self.noise_scheduler.timesteps:
# Expand the latents if we are doing classifier-free guidance
latent_model_input = torch.cat([latents] * 2) if args.guidance_scale > 1.0 else latents
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, t)
# Predict the noise residual
with torch.no_grad():
noise_pred = transformer(
latent_model_input,
t.unsqueeze(0).repeat(latent_model_input.shape[0]),
encoder_hidden_states=torch.cat([text_embeddings] * 2) if args.guidance_scale > 1.0 else text_embeddings,
attention_mask=attention_mask,
return_dict=True,
).sample
# Perform guidance
if args.guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1
latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
# Scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
image = vae.decode(latents).sample
# Convert to PIL image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
all_images.extend(pil_images)
# Log images to tensorboard if available
if accelerator.is_main_process and hasattr(accelerator, "log"):
log_images = []
for i, img in enumerate(all_images):
# Convert PIL image to numpy for logging
log_images.append(np.array(img))
# Save individual images
os.makedirs(os.path.join(args.output_dir, "samples"), exist_ok=True)
img.save(os.path.join(args.output_dir, "samples", f"sample_epoch{epoch}_step{global_step}_{i}.png"))
# Log to tensorboard
accelerator.log({
"samples": [
wandb.Image(img, caption=f"{sample_prompts[i]}")
for i, img in enumerate(log_images)
]
}, step=global_step)
# Set models back to training mode if they were training before
if was_training:
transformer.train()
vae.train()
return all_images
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# Return the noise scheduler that was created during model loading
return self.noise_scheduler
def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
transformer: CogView4Transformer2DModel,
network,
weight_dtype,
train_unet=True,
is_train=True,
):
"""
Get noise prediction and target for the loss computation.
"""
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
).long()
# Add noise to the latents according to the noise magnitude at each timestep
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get text embeddings
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask", None)
# Prepare text encoder outputs
with torch.set_grad_enabled(self.train_text_encoder):
text_embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state
# Predict the noise residual
with torch.set_grad_enabled(train_unet):
# Add conditioning dropout for classifier-free guidance
if args.guidance_scale > 1.0:
# Randomly drop text conditioning 5% of the time
mask = (torch.rand(bsz, device=latents.device) < 0.05).float().unsqueeze(1).unsqueeze(1)
text_embeddings = text_embeddings * (1 - mask) + torch.zeros_like(text_embeddings) * mask
# Predict noise
model_pred = transformer(
noisy_latents,
timesteps,
encoder_hidden_states=text_embeddings,
attention_mask=attention_mask,
return_dict=True,
).sample
# Calculate target
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# For classifier-free guidance, we need to do two forward passes
if args.guidance_scale > 1.0:
# Predict the conditional and unconditional outputs
model_pred_uncond = transformer(
noisy_latents,
timesteps,
encoder_hidden_states=torch.zeros_like(text_embeddings),
attention_mask=attention_mask,
return_dict=True,
).sample
# Perform classifier-free guidance
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
# For training, we only compute the loss on the conditional prediction
if is_train:
model_pred = model_pred_uncond + args.guidance_scale * (model_pred - model_pred_uncond)
# Simple weighting - can be adjusted based on timestep if needed
weighting = torch.ones_like(timesteps, dtype=weight_dtype, device=latents.device)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
"""
Post-process the loss value.
This can include applying timestep weighting, gradient clipping, etc.
"""
# Apply timestep weighting if specified
if hasattr(args, 'timestep_bias_portion') and args.timestep_bias_portion > 0.0:
# Simple timestep weighting - can be made more sophisticated if needed
weights = torch.ones_like(timesteps, dtype=torch.float32)
if hasattr(args, 'timestep_bias_begin') and args.timestep_bias_begin > 0:
mask = timesteps < args.timestep_bias_begin
weights[mask] = 0.0
if hasattr(args, 'timestep_bias_end') and args.timestep_bias_end < 1000:
mask = timesteps > args.timestep_bias_end
weights[mask] = 0.0
if hasattr(args, 'timestep_bias_multiplier') and args.timestep_bias_multiplier != 1.0:
weights = weights * args.timestep_bias_multiplier
loss = loss * weights.to(loss.device)
# Clip loss values if specified
if hasattr(args, 'clip_grad_norm') and args.clip_grad_norm > 0.0:
loss = torch.clamp(loss, -args.clip_grad_norm, args.clip_grad_norm)
return loss
def prepare_extra_step_kwargs(self, generator, eta):
"""
Prepare extra kwargs for the scheduler step, such as the generator for reproducibility.
"""
# Prepare extra step kwargs.
# TODO: Logic should ideally just be moved to base class
accepts_eta = "eta" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# Check if the scheduler accepts a generator
accepts_generator = "generator" in set(inspect.signature(self.noise_scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def update_metadata(self, metadata, args):
metadata["ss_apply_attn_mask"] = args.apply_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
"""Check if text encoder outputs are cached and not being trained."""
return getattr(args, 'cache_text_encoder_outputs', False) and not getattr(args, 'train_text_encoder', False)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
"""Prepare text encoder for gradient checkpointing.
For CogView4, we only have one text encoder (GLM) so we don't need index-based handling.
The base class method handles enabling gradient checkpointing if args specify it.
"""
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
"""Prepare text encoder for FP8 training.
Args:
index: Text encoder index (always 0 for CogView4)
text_encoder: The text encoder model (GLM)
te_weight_dtype: Target weight dtype for the encoder
weight_dtype: Base weight dtype for embeddings
"""
if index != 0:
logger.warning(f"Unexpected text encoder index {index} for CogView4, expecting 0.")
# Still proceed, assuming it's the single GLM encoder
logger.info(f"Preparing GLM text encoder (index {index}) for {te_weight_dtype}, embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype)
# Move embeddings to base weight dtype if they exist
if hasattr(text_encoder, 'word_embeddings'): # GLM typically has word_embeddings
text_encoder.word_embeddings.to(dtype=weight_dtype)
if hasattr(text_encoder, 'position_embeddings'): # GLM might have position_embeddings
text_encoder.position_embeddings.to(dtype=weight_dtype)
# Add other relevant parts of GLM if they need specific dtype handling for FP8
return text_encoder
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
"""Called at the end of each validation step."""
# No special handling needed for CogView4 (e.g., no block swapping)
pass
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
"""Prepare UNet model (CogView4Transformer2DModel) with accelerator.
For CogView4, we use standard model preparation.
"""
return super().prepare_unet_with_accelerator(args, accelerator, unet)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = CogView4NetworkTrainer()
trainer.train(args)

docs/config_README-en.md

@@ -122,11 +122,15 @@ These are options related to the configuration of the data set. They cannot be d
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
* `skip_image_resolution`
* Images whose original resolution (area) is equal to or smaller than the specified resolution will be skipped. Specify as `'size'` or `[width, height]`. This corresponds to the command-line argument `--skip_image_resolution`.
* Useful when sharing the same image directory across multiple datasets with different resolutions, to exclude low-resolution source images from higher-resolution datasets.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
@@ -254,6 +258,34 @@ resolution = 768
image_dir = 'C:\hoge'
```
When using multi-resolution datasets, you can use `skip_image_resolution` to exclude images whose original size is too small for higher-resolution datasets. This prevents overlapping of low-resolution images across datasets and improves training quality. This option can also be used to simply exclude low-resolution source images from datasets.
```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
In this example, the 1024-resolution dataset skips images whose original size is 768x768 or smaller, and the 1280-resolution dataset skips images whose original size is 1024x1024 or smaller.
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
@@ -284,6 +316,7 @@ For the command line options listed below, if an option is specified in both the
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |
## Error Guide

docs/config_README-ja.md

@@ -115,11 +115,15 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
| `skip_image_resolution` | `768`, `[512, 768]` | o | o |
* `batch_size`
* コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
* `skip_image_resolution`
* 指定した解像度(面積)以下の画像をスキップします。`'サイズ'` または `[幅, 高さ]` で指定します。コマンドライン引数の `--skip_image_resolution` と同等です。
* 同じ画像ディレクトリを異なる解像度の複数のデータセットで使い回す場合に、低解像度の元画像を高解像度のデータセットから除外するために使用します。
これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
@@ -259,6 +263,34 @@ resolution = 768
image_dir = 'C:\hoge'
```
なお、マルチ解像度データセットでは `skip_image_resolution` を使用して、元の画像サイズが小さい画像を高解像度データセットから除外できます。これにより、低解像度画像のデータセット間での重複を防ぎ、学習品質を向上させることができます。また、小さい画像を除外するフィルターとしても機能します。
```toml
[general]
enable_bucket = true
bucket_no_upscale = true
max_bucket_reso = 1536
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1024
skip_image_resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 1280
skip_image_resolution = 1024
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
この例では、1024 解像度のデータセットでは元の画像サイズが 768x768 以下の画像がスキップされ、1280 解像度のデータセットでは 1024x1024 以下の画像がスキップされます。
## コマンドライン引数との併用
設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
@@ -289,6 +321,7 @@ resolution = 768
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--skip_image_resolution` | |
| `--train_batch_size` | `batch_size` |
## エラーの手引き

736
docs/train_leco.md Normal file

@@ -0,0 +1,736 @@
# LECO Training Guide / LECO 学習ガイド
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
This repository provides two LECO training scripts:
- `train_leco.py` for Stable Diffusion 1.x / 2.x
- `sdxl_train_leco.py` for SDXL
<details>
<summary>日本語</summary>
LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
- `sdxl_train_leco.py` : SDXL 用
</details>
## 1. Overview / 概要
### What LECO Can Do / LECO でできること
LECO can be used for:
- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")
Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.
<details>
<summary>日本語</summary>
LECO は以下の用途に使用できます:
- **概念の消去**: 特定のスタイルや概念を除去する(生成画像から「van gogh」スタイルを消去)
- **概念の強化**: 特定の属性を強化する(「detailed」をより顕著にする)
- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(「short hair」と「long hair」の間のスライダー)
通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
</details>
### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い
| | Standard LoRA | LECO |
|---|---|---|
| Training data | Image dataset required | **No images needed** |
| Configuration | Dataset TOML | Prompt TOML |
| Training target | U-Net and/or Text Encoder | **U-Net only** |
| Training unit | Epochs and steps | **Steps only** |
| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |
<details>
<summary>日本語</summary>
| | 通常の LoRA | LECO |
|---|---|---|
| 学習データ | 画像データセットが必要 | **画像不要** |
| 設定ファイル | データセット TOML | プロンプト TOML |
| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
| 学習単位 | エポックとステップ | **ステップのみ** |
| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
</details>
## 2. Prompt Configuration File / プロンプト設定ファイル
LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).
<details>
<summary>日本語</summary>
LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
</details>
### 2.1. Original LECO Format / オリジナル LECO 形式
Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.
```toml
[[prompts]]
target = "van gogh"
positive = "van gogh"
unconditional = ""
neutral = ""
action = "erase"
guidance_scale = 1.0
resolution = 512
batch_size = 1
multiplier = 1.0
weight = 1.0
```
Each `[[prompts]]` entry defines one training pair with the following fields:
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target` | Yes | - | The concept to be modified by the LoRA |
| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
| `unconditional` | No | `""` | The unconditional/negative prompt |
| `neutral` | No | `""` | The neutral baseline prompt |
| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
| `weight` | No | `1.0` | Loss weight for this prompt pair |
<details>
<summary>日本語</summary>
`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
`[[prompts]]` エントリのフィールド:
| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target` | はい | - | LoRA によって変更される概念 |
| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
</details>
### 2.2. Slider Target Format / スライダーターゲット形式
Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).
```toml
guidance_scale = 1.0
resolution = 1024
neutral = ""
[[targets]]
target_class = "1girl"
positive = "1girl, long hair"
negative = "1girl, short hair"
multiplier = 1.0
weight = 1.0
```
Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.
Each `[[targets]]` entry supports the following fields:
| Field | Required | Default | Description |
|-------|----------|---------|-------------|
| `target_class` | Yes | - | The base class/subject prompt |
| `positive` | No* | `""` | Prompt for the positive direction of the slider |
| `negative` | No* | `""` | Prompt for the negative direction of the slider |
| `multiplier` | No | `1.0` | LoRA strength multiplier |
| `weight` | No | `1.0` | Loss weight |
\* At least one of `positive` or `negative` must be provided.
<details>
<summary>日本語</summary>
`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
`[[targets]]` エントリのフィールド:
| フィールド | 必須 | デフォルト | 説明 |
|-----------|------|-----------|------|
| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
| `weight` | いいえ | `1.0` | loss 重み |
\* `positive``negative` のうち少なくとも一方を指定する必要があります。
</details>
### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト
You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.
```toml
guidance_scale = 1.5
resolution = 1024
neutrals = ["", "photo of a person", "cinematic portrait"]
[[targets]]
target_class = "person"
positive = "smiling person"
negative = "expressionless person"
```
You can also load neutral prompts from a text file (one prompt per line):
```toml
neutral_prompt_file = "neutrals.txt"
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```
<details>
<summary>日本語</summary>
スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
</details>
### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換
If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:
<details>
<summary>日本語</summary>
既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
</details>
**YAML:**
```yaml
targets:
- target_class: ""
positive: "high detail"
negative: "low detail"
multiplier: 1.0
guidance_scale: 1.0
resolution: 512
```
**TOML:**
```toml
guidance_scale = 1.0
resolution = 512
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.0
```
Key syntax differences:
- Use `=` instead of `:` for key-value pairs
- Use `[[targets]]` header instead of `targets:` with `- ` list items
- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)
<details>
<summary>日本語</summary>
主な構文の違い:
- キーと値の区切りに `:` ではなく `=` を使用
- `targets:``- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
</details>
## 3. Running the Training / 学習の実行
Training is started by executing the script from the terminal. Below are basic command-line examples.
In reality, you need to write the command in a single line, but it is shown with line breaks for readability. On Linux/Mac, add `\` at the end of each line; on Windows, add `^`.
<details>
<summary>日本語</summary>
学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
</details>
### SD 1.x / 2.x
```bash
accelerate launch --mixed_precision bf16 train_leco.py
--pretrained_model_name_or_path="model.safetensors"
--prompts_file="prompts.toml"
--output_dir="output"
--output_name="my_leco"
--network_dim=8
--network_alpha=4
--learning_rate=1e-4
--optimizer_type="AdamW8bit"
--max_train_steps=500
--max_denoising_steps=40
--mixed_precision=bf16
--sdpa
--gradient_checkpointing
--save_every_n_steps=100
```
### SDXL
```bash
accelerate launch --mixed_precision bf16 sdxl_train_leco.py
--pretrained_model_name_or_path="sdxl_model.safetensors"
--prompts_file="slider.toml"
--output_dir="output"
--output_name="my_sdxl_slider"
--network_dim=8
--network_alpha=4
--learning_rate=1e-4
--optimizer_type="AdamW8bit"
--max_train_steps=1000
--max_denoising_steps=40
--mixed_precision=bf16
--sdpa
--gradient_checkpointing
--save_every_n_steps=200
```
## 4. Command-Line Arguments / コマンドライン引数
### 4.1. LECO-Specific Arguments / LECO 固有の引数
These arguments are unique to LECO and not found in standard LoRA training scripts.
<details>
<summary>日本語</summary>
以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
</details>
* `--prompts_file="prompts.toml"` **[Required]**
* Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.
* `--max_denoising_steps=40`
* Number of partial denoising steps per training iteration. At each step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`.
* `--leco_denoise_guidance_scale=3.0`
* Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.
<details>
<summary>日本語</summary>
* `--prompts_file="prompts.toml"` **[必須]**
* LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
* `--max_denoising_steps=40`
* 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`
* `--leco_denoise_guidance_scale=3.0`
* 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`
</details>
#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い
There are two separate guidance scale parameters that control different aspects of LECO training:
1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.
2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.
If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`).
<details>
<summary>日本語</summary>
LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
2. **`guidance_scale`TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
</details>
### 4.2. Model Arguments / モデル引数
* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
* Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).
* `--v2` (SD 1.x/2.x only)
* Specify when using a Stable Diffusion v2.x model.
* `--v_parameterization` (SD 1.x/2.x only)
* Specify when using a v-prediction model (e.g., SD 2.x 768px models).
<details>
<summary>日本語</summary>
* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
* ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)
* `--v2`SD 1.x/2.x のみ)
* Stable Diffusion v2.x モデルを使用する場合に指定します。
* `--v_parameterization`SD 1.x/2.x のみ)
* v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
</details>
### 4.3. LoRA Network Arguments / LoRA ネットワーク引数
* `--network_module=networks.lora`
* Network module to train. Default: `networks.lora`.
* `--network_dim=8`
* LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.
* `--network_alpha=4`
* LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.
* `--network_dropout=0.1`
* Dropout rate for LoRA layers. Optional.
* `--network_args "key=value" ...`
* Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.
* `--network_weights="path/to/weights.safetensors"`
* Load pretrained LoRA weights to continue training.
* `--dim_from_weights`
* Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.
<details>
<summary>日本語</summary>
* `--network_module=networks.lora`
* 学習するネットワークモジュール。デフォルト: `networks.lora`
* `--network_dim=8`
* LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`
* `--network_alpha=4`
* 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`
* `--network_dropout=0.1`
* LoRA レイヤーのドロップアウト率。省略可。
* `--network_args "key=value" ...`
* ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
* `--network_weights="path/to/weights.safetensors"`
* 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
* `--dim_from_weights`
* `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
</details>
### 4.4. Training Parameters / 学習パラメータ
* `--max_train_steps=500`
* Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
* Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).
* `--learning_rate=1e-4`
* Learning rate. Typical range for LECO: `1e-4` to `1e-3`.
* `--unet_lr=1e-4`
* Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.
* `--optimizer_type="AdamW8bit"`
* Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.
* `--lr_scheduler="constant"`
* Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.
* `--lr_warmup_steps=0`
* Number of warmup steps for the learning rate scheduler.
* `--gradient_accumulation_steps=1`
* Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.
* `--max_grad_norm=1.0`
* Maximum gradient norm for gradient clipping. Set to `0` to disable.
* `--min_snr_gamma=5.0`
* Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.
<details>
<summary>日本語</summary>
* `--max_train_steps=500`
* 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`
* 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。
* `--learning_rate=1e-4`
* 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`
* `--unet_lr=1e-4`
* U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。
* `--optimizer_type="AdamW8bit"`
* オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。
* `--lr_scheduler="constant"`
* 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。
* `--lr_warmup_steps=0`
* 学習率スケジューラのウォームアップステップ数。
* `--gradient_accumulation_steps=1`
* 勾配を累積するステップ数。実質的にバッチサイズを増加させます。
* `--max_grad_norm=1.0`
* 勾配クリッピングの最大勾配ノルム。`0` で無効化。
* `--min_snr_gamma=5.0`
* Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。
</details>
### 4.5. Output and Save Arguments / 出力・保存引数
* `--output_dir="output"` **[Required]**
* Directory for saving trained LoRA models and logs.
* `--output_name="my_leco"` **[Required]**
* Base filename for the trained LoRA (without extension).
* `--save_model_as="safetensors"`
* Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.
* `--save_every_n_steps=100`
* Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
* Note: `--save_every_n_epochs` is **not supported** for LECO.
* `--save_precision="fp16"`
* Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.
* `--no_metadata`
* Do not write metadata into the saved model file.
* `--training_comment="my comment"`
* A comment string stored in the model metadata.
<details>
<summary>日本語</summary>
* `--output_dir="output"` **[必須]**
* 学習済み LoRA モデルとログの保存先ディレクトリ。
* `--output_name="my_leco"` **[必須]**
* 学習済み LoRA のベースファイル名(拡張子なし)。
* `--save_model_as="safetensors"`
* モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt``pt` から選択。
* `--save_every_n_steps=100`
* N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
* 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。
* `--save_precision="fp16"`
* モデル保存時の精度。`float``fp16``bf16` から選択。省略時は学習時の精度が使用されます。
* `--no_metadata`
* 保存するモデルファイルにメタデータを書き込みません。
* `--training_comment="my comment"`
* モデルのメタデータに保存されるコメント文字列。
</details>
### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数
* `--mixed_precision="bf16"`
* Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.
* `--full_fp16`
* Train entirely in fp16 precision including gradients.
* `--full_bf16`
* Train entirely in bf16 precision including gradients.
* `--gradient_checkpointing`
* Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.
* `--sdpa`
* Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.
* `--xformers`
* Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`.
* `--mem_eff_attn`
* Use memory-efficient attention implementation. Another alternative to `--sdpa`.
<details>
<summary>日本語</summary>
* `--mixed_precision="bf16"`
* 混合精度学習。`no``fp16``bf16` から選択。`bf16` または `fp16` の使用を推奨します。
* `--full_fp16`
* 勾配を含め全体を fp16 精度で学習します。
* `--full_bf16`
* 勾配を含め全体を bf16 精度で学習します。
* `--gradient_checkpointing`
* gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。
* `--sdpa`
* Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。
* `--xformers`
* xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。
* `--mem_eff_attn`
* メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。
</details>
### 4.7. Other Useful Arguments / その他の便利な引数
* `--seed=42`
* Random seed for reproducibility. If not specified, a random seed is automatically generated.
* `--noise_offset=0.05`
* Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.
* `--zero_terminal_snr`
* Fix noise scheduler betas to enforce zero terminal SNR.
* `--clip_skip=2` (SD 1.x/2.x only)
* Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.
* `--logging_dir="logs"`
* Directory for TensorBoard logs. Enables logging when specified.
* `--log_with="tensorboard"`
* Logging tool. Options: `tensorboard`, `wandb`, `all`.
<details>
<summary>日本語</summary>
* `--seed=42`
* 再現性のための乱数シード。指定しない場合は自動生成されます。
* `--noise_offset=0.05`
* ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。
* `--zero_terminal_snr`
* noise scheduler の betas を修正してゼロ終端 SNR を強制します。
* `--clip_skip=2`(SD 1.x/2.x のみ)
* text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`
* `--logging_dir="logs"`
* TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。
* `--log_with="tensorboard"`
* ログツール。`tensorboard``wandb``all` から選択。
</details>
## 5. Tips / ヒント
### Tuning the Effect Strength / 効果の強さの調整
If the trained LoRA has a weak or unnoticeable effect:
1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect (see the TOML sketch after this list).
2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
4. **Increase `--max_train_steps`** to train longer.
5. **Apply the LoRA with a higher weight** at inference time.
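For example, applying the first two tips to a slider-format target might look like the following sketch (the prompt strings are placeholders; only `guidance_scale` and `multiplier` are changed from their defaults of `1.0`):
```toml
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
guidance_scale = 2.5  # raised from the default 1.0 to strengthen the effect
multiplier = 1.5      # raised from the default 1.0
```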
<details>
<summary>日本語</summary>
学習した LoRA の効果が弱い、または認識できない場合:
1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
4. **`--max_train_steps` を増やして**、より長く学習する。
5. **推論時に LoRA のウェイトを大きくして**適用する。
</details>
### Recommended Starting Settings / 推奨の開始設定
| Parameter | SD 1.x/2.x | SDXL |
|-----------|-------------|------|
| `--network_dim` | `4`-`8` | `8`-`16` |
| `--learning_rate` | `1e-4` | `1e-4` |
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
| `resolution` (in TOML) | `512` | `1024` |
| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |
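As a concrete starting point for SDXL, the "in TOML" values from the table above might be written as top-level defaults of a slider-format prompt file, as in this sketch (the prompt strings are placeholders):
```toml
resolution = 1024      # recommended SDXL starting resolution
guidance_scale = 2.0   # within the recommended 1.0-3.0 range for SDXL
batch_size = 1
[[targets]]
target_class = ""
positive = "detailed, intricate"
negative = "flat, plain"
```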
<details>
<summary>日本語</summary>
| パラメータ | SD 1.x/2.x | SDXL |
|-----------|-------------|------|
| `--network_dim` | `4`-`8` | `8`-`16` |
| `--learning_rate` | `1e-4` | `1e-4` |
| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
| `resolution`(TOML内) | `512` | `1024` |
| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |
</details>
### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)
For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:
```toml
resolution = 1024
dynamic_resolution = true
dynamic_crops = true
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```
- `dynamic_resolution`: Randomly varies the training resolution around the base value using aspect ratio buckets.
- `dynamic_crops`: Randomizes crop positions in the SDXL size conditioning embeddings.
These options can improve the LoRA's generalization across different aspect ratios.
<details>
<summary>日本語</summary>
SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。
- `dynamic_resolution`: アスペクト比バケツを使用して、ベース値の周囲で学習解像度をランダムに変化させます。
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。
これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
</details>
## 6. Using the Trained Model / 学習済みモデルの利用
The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui, ComfyUI, etc.
For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.
<details>
<summary>日本語</summary>
学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。
スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。
</details>

View File

@@ -739,13 +739,16 @@ class FinalLayer(nn.Module):
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
2, dim=-1
)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")
@@ -864,32 +867,34 @@ class Block(nn.Module):
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if x_B_T_H_W_D.dtype == torch.float16:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
# Compute AdaLN modulation parameters
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
3, dim=-1
)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
3, dim=-1
)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
3, dim=-1
)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
# Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D)
shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")

View File

@@ -108,6 +108,7 @@ class BaseDatasetParams:
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
skip_image_resolution: Optional[Tuple[int, int]] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -118,7 +119,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -244,6 +245,7 @@ class ConfigSanitizer:
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
"skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
}
# options handled by argparse but not handled by user config
@@ -256,6 +258,7 @@ class ConfigSanitizer:
ARGPARSE_NULLABLE_OPTNAMES = [
"face_crop_aug_range",
"resolution",
"skip_image_resolution",
]
# prepare map because option name may differ among argparse and user config
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
@@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
skip_image_resolution: {dataset.skip_image_resolution}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")

library/leco_train_util.py (new file, 522 lines)

@@ -0,0 +1,522 @@
import argparse
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import torch
import toml
from torch.utils.checkpoint import checkpoint
from library import train_util
import logging
logger = logging.getLogger(__name__)
def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]:
kwargs = {}
if args.network_args:
for net_arg in args.network_args:
key, value = net_arg.split("=", 1)
kwargs[key] = value
if "dropout" not in kwargs:
kwargs["dropout"] = args.network_dropout
return kwargs
def get_save_extension(args: argparse.Namespace) -> str:
if args.save_model_as == "ckpt":
return ".ckpt"
if args.save_model_as == "pt":
return ".pt"
return ".safetensors"
def save_weights(
accelerator,
network,
args: argparse.Namespace,
save_dtype,
prompt_settings,
global_step: int,
last: bool = False,
extra_metadata: Optional[Dict[str, str]] = None,
) -> None:
os.makedirs(args.output_dir, exist_ok=True)
ext = get_save_extension(args)
ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata = None
if not args.no_metadata:
metadata = {
"ss_network_module": args.network_module,
"ss_network_dim": str(args.network_dim),
"ss_network_alpha": str(args.network_alpha),
"ss_leco_prompt_count": str(len(prompt_settings)),
"ss_leco_prompts_file": os.path.basename(args.prompts_file),
}
if extra_metadata:
metadata.update(extra_metadata)
if args.training_comment:
metadata["ss_training_comment"] = args.training_comment
metadata["ss_leco_preview"] = json.dumps(
[
{
"target": p.target,
"positive": p.positive,
"unconditional": p.unconditional,
"neutral": p.neutral,
"action": p.action,
"multiplier": p.multiplier,
"weight": p.weight,
}
for p in prompt_settings[:16]
],
ensure_ascii=False,
)
unwrapped = accelerator.unwrap_model(network)
unwrapped.save_weights(ckpt_file, save_dtype, metadata)
logger.info(f"saved model to: {ckpt_file}")
ResolutionValue = Union[int, Tuple[int, int]]
@dataclass
class PromptEmbedsXL:
text_embeds: torch.Tensor
pooled_embeds: torch.Tensor
class PromptEmbedsCache:
def __init__(self):
self.prompts: dict[str, Any] = {}
def __setitem__(self, name: str, value: Any) -> None:
self.prompts[name] = value
def __getitem__(self, name: str) -> Any:
return self.prompts[name]
@dataclass
class PromptSettings:
target: str
positive: Optional[str] = None
unconditional: str = ""
neutral: Optional[str] = None
action: str = "erase"
guidance_scale: float = 1.0
resolution: ResolutionValue = 512
dynamic_resolution: bool = False
batch_size: int = 1
dynamic_crops: bool = False
multiplier: float = 1.0
weight: float = 1.0
def __post_init__(self):
if self.positive is None:
self.positive = self.target
if self.neutral is None:
self.neutral = self.unconditional
if self.action not in ("erase", "enhance"):
raise ValueError(f"Invalid action: {self.action}")
self.guidance_scale = float(self.guidance_scale)
self.batch_size = int(self.batch_size)
self.multiplier = float(self.multiplier)
self.weight = float(self.weight)
self.dynamic_resolution = bool(self.dynamic_resolution)
self.dynamic_crops = bool(self.dynamic_crops)
self.resolution = normalize_resolution(self.resolution)
def get_resolution(self) -> Tuple[int, int]:
if isinstance(self.resolution, tuple):
return self.resolution
return (self.resolution, self.resolution)
def build_target(self, positive_latents, neutral_latents, unconditional_latents):
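# offset points from the unconditional latents toward the positive latents, scaled by guidance_scale;
# "erase" pushes the neutral latents away from that direction, "enhance" pulls them toward it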
offset = self.guidance_scale * (positive_latents - unconditional_latents)
if self.action == "erase":
return neutral_latents - offset
return neutral_latents + offset
def normalize_resolution(value: Any) -> ResolutionValue:
if isinstance(value, tuple):
if len(value) != 2:
raise ValueError(f"resolution tuple must have 2 items: {value}")
return (int(value[0]), int(value[1]))
if isinstance(value, list):
if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
return (int(value[0]), int(value[1]))
raise ValueError(f"resolution list must have 2 numeric items: {value}")
return int(value)
def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
with open(path, "r", encoding="utf-8") as f:
return [line.strip() for line in f.readlines() if line.strip()]
def _recognized_prompt_keys() -> set[str]:
return {
"target",
"positive",
"unconditional",
"neutral",
"action",
"guidance_scale",
"resolution",
"dynamic_resolution",
"batch_size",
"dynamic_crops",
"multiplier",
"weight",
}
def _recognized_slider_keys() -> set[str]:
return {
"target_class",
"positive",
"negative",
"neutral",
"guidance_scale",
"resolution",
"resolutions",
"dynamic_resolution",
"batch_size",
"dynamic_crops",
"multiplier",
"weight",
}
def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
merged = {k: v for k, v in defaults.items() if k in known_keys}
merged.update(item)
return merged
def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
if value is None:
return [512]
if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
return [normalize_resolution(v) for v in value]
return [normalize_resolution(value)]
def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
target_class = str(target.get("target_class", ""))
positive = str(target.get("positive", "") or "")
negative = str(target.get("negative", "") or "")
multiplier = target.get("multiplier", 1.0)
resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
if not positive.strip() and not negative.strip():
raise ValueError("slider target requires either positive or negative prompt")
base = dict(
target=target_class,
neutral=neutral,
guidance_scale=target.get("guidance_scale", 1.0),
dynamic_resolution=target.get("dynamic_resolution", False),
batch_size=target.get("batch_size", 1),
dynamic_crops=target.get("dynamic_crops", False),
weight=target.get("weight", 1.0),
)
# Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs.
# With both positive and negative: 4 pairs; with only one: 2 pairs.
pairs: list[tuple[str, str, str, float]] = []
if positive.strip() and negative.strip():
pairs = [
(negative, positive, "erase", multiplier),
(positive, negative, "enhance", multiplier),
(positive, negative, "erase", -multiplier),
(negative, positive, "enhance", -multiplier),
]
elif negative.strip():
pairs = [
(negative, "", "erase", multiplier),
(negative, "", "enhance", -multiplier),
]
else:
pairs = [
(positive, "", "enhance", multiplier),
(positive, "", "erase", -multiplier),
]
prompt_settings: List[PromptSettings] = []
for resolution in resolutions:
for pos, uncond, action, mult in pairs:
prompt_settings.append(
PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
)
return prompt_settings
def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
path = Path(path)
with open(path, "r", encoding="utf-8") as f:
data = toml.load(f)
if not data:
raise ValueError("prompt file is empty")
default_prompt_values = {
"guidance_scale": 1.0,
"resolution": 512,
"dynamic_resolution": False,
"batch_size": 1,
"dynamic_crops": False,
"multiplier": 1.0,
"weight": 1.0,
}
prompt_settings: List[PromptSettings] = []
def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
prompt_settings.append(PromptSettings(**merged))
def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
if not neutral_values:
neutral_values = [str(merged.get("neutral", "") or "")]
for neutral in neutral_values:
prompt_settings.extend(_expand_slider_target(merged, neutral))
if "prompts" in data:
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
for item in data["prompts"]:
if "target_class" in item:
append_slider_item(item, defaults, [str(item.get("neutral", "") or "")])
else:
append_prompt_item(item, defaults)
else:
slider_config = data.get("slider", data)
targets = slider_config.get("targets")
if targets is None:
if "target_class" in slider_config:
targets = [slider_config]
elif "target" in slider_config:
targets = [slider_config]
else:
raise ValueError("prompt file does not contain prompts or slider targets")
if len(targets) == 0:
raise ValueError("prompt file contains an empty targets list")
if "target" in targets[0]:
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
for item in targets:
append_prompt_item(item, defaults)
else:
defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
neutral_values: List[str] = []
if "neutrals" in slider_config:
neutral_values.extend(str(v) for v in slider_config["neutrals"])
if "neutral_prompt_file" in slider_config:
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
if "prompt_file" in slider_config:
neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
if not neutral_values:
neutral_values = [str(slider_config.get("neutral", "") or "")]
for item in targets:
item_neutrals = neutral_values
if "neutrals" in item:
item_neutrals = [str(v) for v in item["neutrals"]]
elif "neutral_prompt_file" in item:
item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
elif "prompt_file" in item:
item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
elif "neutral" in item:
item_neutrals = [str(item["neutral"] or "")]
append_slider_item(item, defaults, item_neutrals)
if not prompt_settings:
raise ValueError("no prompt settings found")
return prompt_settings
def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
tokens = tokenize_strategy.tokenize(prompt)
return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
tokens = tokenize_strategy.tokenize(prompt)
hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
if noise_offset is None:
return latents
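# generate the offset noise in float32 on CPU first, then cast to the latents' dtype/device so fp16 latents are not implicitly upcast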
noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
noise = noise.to(dtype=latents.dtype, device=latents.device)
return latents + noise_offset * noise
def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
noise = torch.randn(
(batch_size, 4, height // 8, width // 8),
device="cpu",
).repeat(n_prompts, 1, 1, 1)
return noise * scheduler.init_noise_sigma
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
"""Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
def _run_with_checkpoint(function, *args):
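# use non-reentrant gradient checkpointing only when gradients are enabled; fall back to a plain call during inference (no_grad)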
if torch.is_grad_enabled():
return checkpoint(function, *args, use_reentrant=False)
return function(*args)
def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
def run_unet(model_input, encoder_hidden_states):
return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion(
unet,
scheduler,
latents: torch.Tensor,
text_embeddings: torch.Tensor,
total_timesteps: int,
start_timesteps: int = 0,
guidance_scale: float = 3.0,
):
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
return latents
def get_add_time_ids(
height: int,
width: int,
dynamic_crops: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
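# with dynamic_crops, pretend the image was cropped from a randomly enlarged original so the SDXL size-conditioning crop coordinates vary between samples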
if dynamic_crops:
random_scale = torch.rand(1).item() * 2 + 1
original_size = (int(height * random_scale), int(width * random_scale))
crops_coords_top_left = (
torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
)
target_size = (height, width)
else:
original_size = (height, width)
crops_coords_top_left = (0, 0)
target_size = (height, width)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
if device is not None:
add_time_ids = add_time_ids.to(device)
return add_time_ids
def predict_noise_xl(
unet,
scheduler,
timestep,
latents: torch.Tensor,
prompt_embeds: PromptEmbedsXL,
add_time_ids: torch.Tensor,
guidance_scale: float = 1.0,
):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
orig_size = add_time_ids[:, :2]
crop_size = add_time_ids[:, 2:4]
target_size = add_time_ids[:, 4:6]
from library import sdxl_train_util
size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
def run_unet(model_input, text_embeds, vector_embeds):
return unet(model_input, timestep, text_embeds, vector_embeds)
noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion_xl(
unet,
scheduler,
latents: torch.Tensor,
prompt_embeds: PromptEmbedsXL,
add_time_ids: torch.Tensor,
total_timesteps: int,
start_timesteps: int = 0,
guidance_scale: float = 3.0,
):
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
noise_pred = predict_noise_xl(
unet,
scheduler,
timestep,
latents,
prompt_embeds=prompt_embeds,
add_time_ids=add_time_ids,
guidance_scale=guidance_scale,
)
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
return latents
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
max_resolution = bucket_resolution
min_resolution = bucket_resolution // 2
step = 64
min_step = min_resolution // step
max_step = max_resolution // step
height = torch.randint(min_step, max_step + 1, (1,)).item() * step
width = torch.randint(min_step, max_step + 1, (1,)).item() * step
return height, width
def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
height, width = prompt.get_resolution()
if prompt.dynamic_resolution and height == width:
return get_random_resolution_in_bucket(height)
return height, width

View File

@@ -0,0 +1,379 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import AutoTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
GLM_TOKENIZER_ID = "THUDM/CogView4-6B"
class CogView4TokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.max_length = max_length
self.tokenizer = self._load_tokenizer(AutoTokenizer, GLM_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
# Add special tokens if needed
self.tokenizer.pad_token = self.tokenizer.eos_token
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
# Tokenize with GLM tokenizer
tokens = self.tokenizer(
text,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
return [input_ids, attention_mask]
class CogView4TextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_attention_mask: bool = True) -> None:
"""
Args:
apply_attention_mask: Whether to apply attention mask during encoding.
"""
self.apply_attention_mask = apply_attention_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_attention_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference
if apply_attention_mask is None:
apply_attention_mask = self.apply_attention_mask
# Get GLM model (should be the only model in the list)
glm_model = models[0]
input_ids = tokens[0]
attention_mask = tokens[1] if len(tokens) > 1 else None
# Move tensors to the correct device
device = glm_model.device
input_ids = input_ids.to(device)
if attention_mask is not None:
attention_mask = attention_mask.to(device)
# Get GLM model outputs
with torch.no_grad():
outputs = glm_model(
input_ids=input_ids,
attention_mask=attention_mask if apply_attention_mask else None,
output_hidden_states=True,
return_dict=True
)
# Get the last hidden state
hidden_states = outputs.hidden_states[-1] # [batch_size, seq_len, hidden_size]
# For compatibility with existing code, we'll return a list similar to the original
# but with GLM's hidden states instead of CLIP/T5 outputs
return [
hidden_states, # Replaces l_pooled
hidden_states, # Replaces t5_out (same tensor for now, can be modified if needed)
torch.zeros(hidden_states.shape[0], hidden_states.shape[1], 3, device=device), # txt_ids placeholder
attention_mask # attention mask
]
class CogView4TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_cogview4_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_attention_mask: bool = True,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_attention_mask = apply_attention_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + CogView4TextEncoderOutputsCachingStrategy.COGVIEW4_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
required_fields = ["hidden_states", "attention_mask", "apply_attention_mask"]
for field in required_fields:
if field not in npz:
return False
npz_apply_attention_mask = bool(npz["apply_attention_mask"])
if npz_apply_attention_mask != self.apply_attention_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
logger.exception(e)
return False
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_states = data["hidden_states"]
attention_mask = data["attention_mask"]
return [
hidden_states, # l_pooled replacement
hidden_states, # t5_out replacement
np.zeros((hidden_states.shape[0], hidden_states.shape[1], 3), dtype=np.float32), # txt_ids
attention_mask # attention mask
]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
model_dtype = next(models[0].parameters()).dtype
if model_dtype == torch.float8_e4m3fn or model_dtype == torch.float8_e5m2:
logger.warning(
"Model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_states, _, _, attention_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
if hidden_states.dtype == torch.bfloat16:
hidden_states = hidden_states.float()
hidden_states = hidden_states.cpu().numpy()
attention_mask = attention_mask.cpu().numpy() if attention_mask is not None else None
for i, info in enumerate(infos):
hidden_states_i = hidden_states[i]
attention_mask_i = attention_mask[i] if attention_mask is not None else None
if self.cache_to_disk and hasattr(info, 'text_encoder_outputs_npz'):
np.savez(
info.text_encoder_outputs_npz,
hidden_states=hidden_states_i,
attention_mask=attention_mask_i,
apply_attention_mask=self.apply_attention_mask,
)
else:
info.text_encoder_outputs = (hidden_states_i, hidden_states_i, np.zeros((hidden_states_i.shape[0], 3), dtype=np.float32), attention_mask_i)
class CogView4LatentsCachingStrategy(LatentsCachingStrategy):
COGVIEW4_LATENTS_NPZ_SUFFIX = "_cogview4.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
"""Get the path for cached latents.
Args:
absolute_path: Absolute path to the source image
image_size: Tuple of (height, width) for the target resolution
Returns:
Path to the cached latents file
"""
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ CogView4LatentsCachingStrategy.COGVIEW4_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool
) -> bool:
"""Check if the latents are already cached and valid.
Args:
bucket_reso: Target resolution as (height, width)
npz_path: Path to the cached latents file
flip_aug: Whether flip augmentation was applied
alpha_mask: Whether alpha mask was used
Returns:
bool: True if valid cache exists, False otherwise
"""
# Using 8 as the default number of frames for compatibility
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk(
self,
npz_path: str,
bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""Load latents from disk.
Args:
npz_path: Path to the cached latents file
bucket_reso: Target resolution as (height, width)
Returns:
Tuple containing:
- latents: The loaded latents or None if loading failed
- original_size: Original image size as [height, width]
- crop_top_left: Crop offset as [top, left]
- alpha_mask: Alpha mask if available
- alpha_mask_origin: Original alpha mask if available
"""
# Using 8 as the default number of frames for compatibility
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(
self,
vae: Any,
image_infos: List[Any],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool
) -> None:
"""Cache a batch of latents.
Args:
vae: The VAE model used for encoding
image_infos: List of image information objects
flip_aug: Whether to apply flip augmentation
alpha_mask: Whether to use alpha mask
random_crop: Whether to apply random crop
"""
# Define encoding function that moves output to CPU
def encode_by_vae(img_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
return vae.encode(img_tensor).to("cpu")
# Get VAE device and dtype
vae_device = vae.device
vae_dtype = vae.dtype
# Cache latents using the default implementation
self._default_cache_batch_latents(
encode_by_vae,
vae_device,
vae_dtype,
image_infos,
flip_aug,
alpha_mask,
random_crop,
multi_resolution=True
)
# Clean up GPU memory if not in high VRAM mode
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# Test code for CogView4TokenizeStrategy
tokenizer = CogView4TokenizeStrategy(512)
text = "hello world"
# Test single text tokenization
input_ids, attention_mask = tokenizer.tokenize(text)
print("Input IDs:", input_ids)
print("Attention Mask:", attention_mask)
# Test batch tokenization
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
batch_input_ids, batch_attention_mask = tokenizer.tokenize(texts)
print("\nBatch Input IDs:", batch_input_ids.shape)
print("Batch Attention Mask:", batch_attention_mask.shape)
# Test with a long text
long_text = ",".join(["hello world! this is long text"] * 10)
long_input_ids, long_attention_mask = tokenizer.tokenize(long_text)
print("\nLong text input IDs shape:", long_input_ids.shape)
print("Long text attention mask shape:", long_attention_mask.shape)
# Test text encoding strategy
print("\nTesting text encoding strategy...")
from transformers import AutoModel
# Load a small GLM model for testing
model = AutoModel.from_pretrained("THUDM/glm-10b-chinese", trust_remote_code=True)
model.eval()
encoding_strategy = CogView4TextEncodingStrategy()
tokens = tokenizer.tokenize(texts)
encoded = encoding_strategy.encode_tokens(tokenizer, [model], tokens)
print(f"Number of outputs: {len(encoded)}")
print(f"Hidden states shape: {encoded[0].shape}")
print(f"Attention mask shape: {encoded[3].shape if encoded[3] is not None else 'None'}")
# Test caching strategy
print("\nTesting caching strategy...")
import tempfile
import os
class DummyInfo:
def __init__(self, caption):
self.caption = caption
self.text_encoder_outputs_npz = tempfile.mktemp(suffix=".npz")
# Create test data
infos = [DummyInfo(text) for text in texts]
# Test caching
caching_strategy = CogView4TextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=2,
skip_disk_cache_validity_check=False
)
# Cache the outputs
caching_strategy.cache_batch_outputs(tokenizer, [model], encoding_strategy, infos)
# Check if files were created
for info in infos:
exists = os.path.exists(info.text_encoder_outputs_npz)
print(f"Cache file {info.text_encoder_outputs_npz} exists: {exists}")
# Clean up
for info in infos:
if os.path.exists(info.text_encoder_outputs_npz):
os.remove(info.text_encoder_outputs_npz)

View File

@@ -687,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset):
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
@@ -727,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset):
), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation'
self.resize_interpolation = resize_interpolation
self.skip_image_resolution = skip_image_resolution
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -1103,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all(
[
not (
subset.caption_dropout_rate > 0 and not cache_supports_dropout
subset.caption_dropout_rate > 0
and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -1915,8 +1919,15 @@ class DreamBoothDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2034,6 +2045,24 @@ class DreamBoothDataset(BaseDataset):
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
if self.skip_image_resolution is not None:
filtered_img_paths = []
filtered_sizes = []
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
for img_path, size in zip(img_paths, sizes):
if size is None: # no latents cache file, get image size by reading image file (slow)
size = self.get_image_size(img_path)
if size[0] * size[1] <= skip_image_area:
continue
filtered_img_paths.append(img_path)
filtered_sizes.append(size)
if len(filtered_img_paths) < len(img_paths):
logger.info(
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
)
img_paths = filtered_img_paths
sizes = filtered_sizes
# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
@@ -2059,7 +2088,7 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
captions = [metas[img_path]["caption"] for img_path in img_paths]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
@@ -2200,8 +2229,15 @@ class FineTuningDataset(BaseDataset):
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
@@ -2297,6 +2333,7 @@ class FineTuningDataset(BaseDataset):
tags_list = []
size_set_from_metadata = 0
size_set_from_cache_filename = 0
num_filtered = 0
for image_key in image_keys_sorted_by_length_desc:
img_md = metadata[image_key]
caption = img_md.get("caption")
@@ -2355,6 +2392,16 @@ class FineTuningDataset(BaseDataset):
image_info.image_size = (w, h)
size_set_from_cache_filename += 1
if self.skip_image_resolution is not None:
size = image_info.image_size
if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow)
size = self.get_image_size(abs_path)
image_info.image_size = size
skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1]
if size[0] * size[1] <= skip_image_area:
num_filtered += 1
continue
self.register_image(image_info, subset)
if size_set_from_cache_filename > 0:
@@ -2363,6 +2410,8 @@ class FineTuningDataset(BaseDataset):
)
if size_set_from_metadata > 0:
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
if num_filtered > 0:
logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}")
self.num_train_images += len(metadata) * subset.num_repeats
# TODO do not record tag freq when no tag
@@ -2387,8 +2436,15 @@ class ControlNetDataset(BaseDataset):
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
skip_image_resolution: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(
resolution,
network_multiplier,
debug_dataset,
resize_interpolation,
skip_image_resolution,
)
db_subsets = []
for subset in subsets:
@@ -2440,6 +2496,7 @@ class ControlNetDataset(BaseDataset):
validation_split,
validation_seed,
resize_interpolation,
skip_image_resolution,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -2487,9 +2544,8 @@ class ControlNetDataset(BaseDataset):
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
if len(extra_imgs) > 0:
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -4601,6 +4657,13 @@ def add_dataset_arguments(
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--skip_image_resolution",
type=str,
default=None,
help="images not larger than this resolution will be skipped ('size' or 'width,height')"
" / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,
@@ -5414,6 +5477,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
len(args.resolution) == 2
), f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.skip_image_resolution is not None:
args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")])
if len(args.skip_image_resolution) == 1:
args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0])
assert (
len(args.skip_image_resolution) == 2
), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'','高さ'で指定してください: {args.skip_image_resolution}"
if args.face_crop_aug_range is not None:
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")])
assert (
@@ -6183,10 +6254,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
name = names[lr_index]
logs["lr/" + name] = float(lrs[lr_index])
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()):
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)
if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]:
logs["lr/d*eff_lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"]
)
# scheduler:

View File

@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
index = int(torch.searchsorted(cumulative_sums, target))
index = max(0, min(index, len(S) - 1))
return index
@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
index = int(torch.searchsorted(sum_S_squared, target**2))
index = max(0, min(index, len(S) - 1))
return index
@@ -78,16 +78,23 @@ def index_sv_fro(S, target):
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
index = int(torch.sum(S > min_sv).item()) - 1
index = max(0, min(index, len(S) - 1))
return index
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
weight = weight.reshape(out_size, -1)
_in_size = in_size * kernel_size * kernel_size
if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter)
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -103,10 +110,14 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
return param_dict
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048:
U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter)
Vh = V.T
else:
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -198,10 +209,9 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
return param_dict
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2):
max_old_rank = None
new_alpha = None
verbose_str = "\n"
fro_list = []
if dynamic_method:
@@ -262,10 +272,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter)
if verbose:
max_ratio = param_dict["max_ratio"]
@@ -274,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
verbose_str += f"{block_down_name:75} | "
verbose_str = f"{block_down_name:75} | "
verbose_str += (
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
)
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += "\n"
if dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}"
tqdm.write(verbose_str)
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
@@ -297,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
del param_dict
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, max_old_rank, new_alpha
@@ -336,7 +343,7 @@ def resize(args):
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
)
# update metadata
@@ -414,6 +421,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
parser.add_argument(
"--svd_lowrank_niter",
type=int,
default=2,
help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
" / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
)
return parser

View File

@@ -9,7 +9,7 @@ einops==0.7.0
bitsandbytes
lion-pytorch==0.2.3
schedulefree==1.4
pytorch-optimizer==3.9.0
pytorch-optimizer==3.10.0
prodigy-plus-schedule-free==1.9.2
prodigyopt==1.1.2
tensorboard

sdxl_train_leco.py (new file, 342 lines)

@@ -0,0 +1,342 @@
import argparse
import importlib
import random
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
PromptEmbedsCache,
apply_noise_offset,
batch_add_time_ids,
build_network_kwargs,
concat_embeddings_xl,
diffusion_xl,
encode_prompt_sdxl,
get_add_time_ids,
get_initial_latents,
get_random_resolution,
load_prompt_settings,
predict_noise_xl,
save_weights,
)
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_optimizer_arguments(parser)
train_util.add_training_arguments(parser, support_dreambooth=False)
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
add_logging_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
parser.add_argument(
"--max_denoising_steps",
type=int,
default=40,
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
)
parser.add_argument(
"--leco_denoise_guidance_scale",
type=float,
default=3.0,
help="guidance scale for the partial denoising pass / 部分デイズ時のguidance scale",
)
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
)
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
return parser
def main():
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train_util.verify_training_args(args)
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
if args.output_dir is None:
raise ValueError("--output_dir is required")
if args.network_train_text_encoder_only:
raise ValueError("LECO does not support text encoder LoRA training")
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)
set_seed(args.seed)
accelerator = train_util.prepare_accelerator(args)
weight_dtype, save_dtype = train_util.prepare_dtype(args)
prompt_settings = load_prompt_settings(args.prompts_file)
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
_, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
del vae
text_encoders = [text_encoder1, text_encoder2]
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.train()
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
prompt_cache = PromptEmbedsCache()
unique_prompts = sorted(
{
prompt
for setting in prompt_settings
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
}
)
with torch.no_grad():
for prompt in unique_prompts:
prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
for text_encoder in text_encoders:
text_encoder.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
network_module = importlib.import_module(args.network_module)
net_kwargs = build_network_kwargs(args)
if args.dim_from_weights:
if args.network_weights is None:
raise ValueError("--dim_from_weights requires --network_weights")
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
else:
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
None,
text_encoders,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
network.set_multiplier(0.0)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
logger.info(f"loaded network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing()
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
optimizer_train_fn()
train_util.init_trackers(accelerator, args, "sdxl_leco_train")
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
while global_step < args.max_train_steps:
with accelerator.accumulate(network):
optimizer.zero_grad(set_to_none=True)
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
height, width = get_random_resolution(setting)
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
accelerator.device, dtype=weight_dtype
)
latents = apply_noise_offset(latents, args.noise_offset)
add_time_ids = get_add_time_ids(
height,
width,
dynamic_crops=setting.dynamic_crops,
dtype=weight_dtype,
device=accelerator.device,
)
batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
denoised_latents = diffusion_xl(
unet,
noise_scheduler,
latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
add_time_ids=batched_time_ids,
total_timesteps=timesteps_to,
guidance_scale=args.leco_denoise_guidance_scale,
)
noise_scheduler.set_timesteps(1000, device=accelerator.device)
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
network_multiplier.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
positive_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
neutral_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
unconditional_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
target_latents = predict_noise_xl(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
add_time_ids=batched_time_ids,
guidance_scale=1.0,
)
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() * setting.weight
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(0.0)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"guidance_scale": setting.guidance_scale,
"network_multiplier": setting.multiplier,
}
accelerator.log(logs, step=global_step)
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
accelerator.end_training()
if __name__ == "__main__":
main()
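Note that --leco_denoise_guidance_scale only steers the partial denoising pass above; the per-prompt guidance_scale from the prompt TOML enters through setting.build_target(...). A minimal sketch of that target construction, assuming the original LECO erase/enhance formulation (the actual helper in library/leco_train_util.py may differ in detail):

# Sketch of the erase/enhance target, assuming the original LECO formulation.
import torch

def build_target_sketch(action, guidance_scale, positive, neutral, unconditional):
    delta = guidance_scale * (positive - unconditional)
    if action == "erase":
        return neutral - delta   # push the target concept away from "positive"
    return neutral + delta       # "enhance": pull it toward "positive"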


@@ -0,0 +1,116 @@
from pathlib import Path
import torch
from library.leco_train_util import load_prompt_settings
def test_load_prompt_settings_with_original_format(tmp_path: Path):
prompt_file = tmp_path / "prompts.toml"
prompt_file.write_text(
"""
[[prompts]]
target = "van gogh"
guidance_scale = 1.5
resolution = 512
""".strip(),
encoding="utf-8",
)
prompts = load_prompt_settings(prompt_file)
assert len(prompts) == 1
assert prompts[0].target == "van gogh"
assert prompts[0].positive == "van gogh"
assert prompts[0].unconditional == ""
assert prompts[0].neutral == ""
assert prompts[0].action == "erase"
assert prompts[0].guidance_scale == 1.5
def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
prompt_file = tmp_path / "slider.toml"
prompt_file.write_text(
"""
guidance_scale = 2.0
resolution = 768
neutral = ""
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
multiplier = 1.25
weight = 0.5
""".strip(),
encoding="utf-8",
)
prompts = load_prompt_settings(prompt_file)
assert len(prompts) == 4
first = prompts[0]
second = prompts[1]
third = prompts[2]
fourth = prompts[3]
assert first.target == ""
assert first.positive == "low detail"
assert first.unconditional == "high detail"
assert first.action == "erase"
assert first.multiplier == 1.25
assert first.weight == 0.5
assert first.get_resolution() == (768, 768)
assert second.positive == "high detail"
assert second.unconditional == "low detail"
assert second.action == "enhance"
assert second.multiplier == 1.25
assert third.action == "erase"
assert third.multiplier == -1.25
assert fourth.action == "enhance"
assert fourth.multiplier == -1.25
def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
from library import sdxl_train_util
from library.leco_train_util import PromptEmbedsXL, predict_noise_xl
class DummyScheduler:
def scale_model_input(self, latent_model_input, timestep):
return latent_model_input
class DummyUNet:
def __call__(self, x, timesteps, context, y):
self.x = x
self.timesteps = timesteps
self.context = context
self.y = y
return torch.zeros_like(x)
latents = torch.randn(1, 4, 8, 8)
prompt_embeds = PromptEmbedsXL(
text_embeds=torch.randn(2, 77, 2048),
pooled_embeds=torch.randn(2, 1280),
)
add_time_ids = torch.tensor(
[
[1024, 1024, 0, 0, 1024, 1024],
[1024, 1024, 0, 0, 1024, 1024],
],
dtype=prompt_embeds.pooled_embeds.dtype,
)
unet = DummyUNet()
noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)
expected_size_embeddings = sdxl_train_util.get_size_embeddings(
add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
).to(prompt_embeds.pooled_embeds.dtype)
assert noise_pred.shape == latents.shape
assert unet.context is prompt_embeds.text_embeds
assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))


@@ -0,0 +1,16 @@
import sdxl_train_leco
from library import deepspeed_utils, sdxl_train_util, train_util
def test_syntax():
assert sdxl_train_leco is not None
def test_setup_parser_supports_shared_training_validation():
args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
train_util.verify_training_args(args)
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
assert args.min_snr_gamma is None
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None

15 tests/test_train_leco.py Normal file

@@ -0,0 +1,15 @@
import train_leco
from library import deepspeed_utils, train_util
def test_syntax():
assert train_leco is not None
def test_setup_parser_supports_shared_training_validation():
args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"])
train_util.verify_training_args(args)
assert args.min_snr_gamma is None
assert deepspeed_utils.prepare_deepspeed_plugin(args) is None

319 train_leco.py Normal file

@@ -0,0 +1,319 @@
import argparse
import importlib
import random
import torch
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from tqdm import tqdm
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import custom_train_functions, strategy_sd, train_util
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
from library.leco_train_util import (
PromptEmbedsCache,
apply_noise_offset,
build_network_kwargs,
concat_embeddings,
diffusion,
encode_prompt_sd,
get_initial_latents,
get_random_resolution,
get_save_extension,
load_prompt_settings,
predict_noise,
save_weights,
)
from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_optimizer_arguments(parser)
train_util.add_training_arguments(parser, support_dreambooth=False)
custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
add_logging_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
parser.add_argument(
"--max_denoising_steps",
type=int,
default=40,
help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
)
parser.add_argument(
"--leco_denoise_guidance_scale",
type=float,
default=3.0,
help="guidance scale for the partial denoising pass / 部分デイズ時のguidance scale",
)
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="unsupported for LECO; kept for compatibility / LECOでは未対応",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
)
parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
# dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
return parser
def main():
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train_util.verify_training_args(args)
if args.output_dir is None:
raise ValueError("--output_dir is required")
if args.network_train_text_encoder_only:
raise ValueError("LECO does not support text encoder LoRA training")
if args.seed is None:
args.seed = random.randint(0, 2**32 - 1)
set_seed(args.seed)
accelerator = train_util.prepare_accelerator(args)
weight_dtype, save_dtype = train_util.prepare_dtype(args)
prompt_settings = load_prompt_settings(args.prompts_file)
logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
del vae
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.train()
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
prompt_cache = PromptEmbedsCache()
unique_prompts = sorted(
{
prompt
for setting in prompt_settings
for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
}
)
with torch.no_grad():
for prompt in unique_prompts:
prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
text_encoder.to("cpu")
clean_memory_on_device(accelerator.device)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
network_module = importlib.import_module(args.network_module)
net_kwargs = build_network_kwargs(args)
if args.dim_from_weights:
if args.network_weights is None:
raise ValueError("--dim_from_weights requires --network_weights")
network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
else:
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
None,
text_encoder,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
network.set_multiplier(0.0)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
logger.info(f"loaded network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing()
unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
optimizer_train_fn()
train_util.init_trackers(accelerator, args, "leco_train")
progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
while global_step < args.max_train_steps:
with accelerator.accumulate(network):
optimizer.zero_grad(set_to_none=True)
setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
height, width = get_random_resolution(setting)
latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
accelerator.device, dtype=weight_dtype
)
latents = apply_noise_offset(latents, args.noise_offset)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
denoised_latents = diffusion(
unet,
noise_scheduler,
latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
total_timesteps=timesteps_to,
guidance_scale=args.leco_denoise_guidance_scale,
)
noise_scheduler.set_timesteps(1000, device=accelerator.device)
current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
current_timestep = noise_scheduler.timesteps[current_timestep_index]
network_multiplier.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
positive_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
guidance_scale=1.0,
)
neutral_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
guidance_scale=1.0,
)
unconditional_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
guidance_scale=1.0,
)
network_multiplier.set_multiplier(setting.multiplier)
with accelerator.autocast():
target_latents = predict_noise(
unet,
noise_scheduler,
current_timestep,
denoised_latents,
concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
guidance_scale=1.0,
)
target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
loss = loss.mean(dim=(1, 2, 3))
if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() * setting.weight
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
network_multiplier = accelerator.unwrap_model(network)
network_multiplier.set_multiplier(0.0)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"guidance_scale": setting.guidance_scale,
"network_multiplier": setting.multiplier,
}
accelerator.log(logs, step=global_step)
progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
accelerator.end_training()
if __name__ == "__main__":
main()
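Both trainers denoise partially on a short schedule and then score a single timestep on the full 1000-step schedule; the mapping is the integer scaling used in the loops above. A small worked example with the default --max_denoising_steps of 40:

max_denoising_steps = 40              # --max_denoising_steps (default)
timesteps_to = 10                     # sampled from [1, max_denoising_steps)
current_timestep_index = int(timesteps_to * 1000 / max_denoising_steps)
print(current_timestep_index)         # 250 -> noise_scheduler.timesteps[250]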


@@ -90,40 +90,23 @@ class NetworkTrainer:
if lr_descriptions is not None:
lr_desc = lr_descriptions[i]
else:
idx = i - (0 if args.network_train_unet_only else -1)
idx = i - (0 if args.network_train_unet_only else 1)
if idx == -1:
lr_desc = "textencoder"
else:
if len(lrs) > 2:
lr_desc = f"group{idx}"
lr_desc = f"group{i}"
else:
lr_desc = "unet"
logs[f"lr/{lr_desc}"] = lr
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
# tracking d*lr value
logs[f"lr/d*lr/{lr_desc}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet.
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
else:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1
for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()):
opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer
if opt is not None:
logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"]
if "effective_lr" in opt.param_groups[i]:
logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"]
return logs
@@ -1085,6 +1068,7 @@ class NetworkTrainer:
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"skip_image_resolution": dataset.skip_image_resolution,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
@@ -1191,6 +1175,7 @@ class NetworkTrainer:
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_skip_image_resolution": dataset.skip_image_resolution,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),