diff --git a/.gitignore b/.gitignore
index f5772a7f..79b9dc3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ GEMINI.md
.claude
.gemini
MagicMock
-references
\ No newline at end of file
+.codex-tmp
+references
diff --git a/docs/train_leco.md b/docs/train_leco.md
new file mode 100644
index 00000000..0896c58c
--- /dev/null
+++ b/docs/train_leco.md
@@ -0,0 +1,736 @@
+# LECO Training Guide / LECO 学習ガイド
+
+LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
+
+This repository provides two LECO training scripts:
+
+- `train_leco.py` for Stable Diffusion 1.x / 2.x
+- `sdxl_train_leco.py` for SDXL
+
+
+日本語
+
+LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
+
+このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
+
+- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
+- `sdxl_train_leco.py` : SDXL 用
+
+
+## 1. Overview / 概要
+
+### What LECO Can Do / LECO でできること
+
+LECO can be used for:
+
+- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images)
+- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced)
+- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair")
+
+Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts.
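+
+Concretely, for each prompt pair the training target is assembled from three noise predictions of the frozen base model. A pseudocode sketch, mirroring `PromptSettings.build_target` in `library/leco_train_util.py`:
+
+```
+offset = guidance_scale * (positive_pred - unconditional_pred)
+target = neutral_pred - offset   # action = "erase"
+target = neutral_pred + offset   # action = "enhance"
+```
+
+The LoRA-enabled prediction for the `target` prompt is then trained toward this tensor.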
+
+
+日本語
+
+LECO は以下の用途に使用できます:
+
+- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
+- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
+- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)
+
+通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
+
+
+### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い
+
+| | Standard LoRA | LECO |
+|---|---|---|
+| Training data | Image dataset required | **No images needed** |
+| Configuration | Dataset TOML | Prompt TOML |
+| Training target | U-Net and/or Text Encoder | **U-Net only** |
+| Training unit | Epochs and steps | **Steps only** |
+| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) |
+
+
+日本語
+
+| | 通常の LoRA | LECO |
+|---|---|---|
+| 学習データ | 画像データセットが必要 | **画像不要** |
+| 設定ファイル | データセット TOML | プロンプト TOML |
+| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
+| 学習単位 | エポックとステップ | **ステップのみ** |
+| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
+
+
+## 2. Prompt Configuration File / プロンプト設定ファイル
+
+LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style).
+
+
+日本語
+LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
+
+
+### 2.1. Original LECO Format / オリジナル LECO 形式
+
+Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair.
+
+```toml
+[[prompts]]
+target = "van gogh"
+positive = "van gogh"
+unconditional = ""
+neutral = ""
+action = "erase"
+guidance_scale = 1.0
+resolution = 512
+batch_size = 1
+multiplier = 1.0
+weight = 1.0
+```
+
+Each `[[prompts]]` entry defines one training pair with the following fields:
+
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `target` | Yes | - | The concept to be modified by the LoRA |
+| `positive` | No | same as `target` | The "positive direction" prompt for building the training target |
+| `unconditional` | No | `""` | The unconditional/negative prompt |
+| `neutral` | No | `""` | The neutral baseline prompt |
+| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it |
+| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) |
+| `resolution` | No | `512` | Training resolution (int or `[height, width]`) |
+| `batch_size` | No | `1` | Number of latent samples per training step for this prompt |
+| `multiplier` | No | `1.0` | LoRA strength multiplier during training |
+| `weight` | No | `1.0` | Loss weight for this prompt pair |
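+
+For example, an enhance-type entry that strengthens the "detailed" concept might look like this (the prompt and scale values are illustrative, not tuned):
+
+```toml
+[[prompts]]
+target = "detailed"
+positive = "very detailed, intricate"
+unconditional = ""
+neutral = ""
+action = "enhance"
+guidance_scale = 1.5
+```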
+
+
+日本語
+
+`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
+
+各 `[[prompts]]` エントリのフィールド:
+
+| フィールド | 必須 | デフォルト | 説明 |
+|-----------|------|-----------|------|
+| `target` | はい | - | LoRA によって変更される概念 |
+| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
+| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
+| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
+| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
+| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
+| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
+| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
+| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
+| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
+
+
+### 2.2. Slider Target Format / スライダーターゲット形式
+
+Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided).
+
+```toml
+guidance_scale = 1.0
+resolution = 1024
+neutral = ""
+
+[[targets]]
+target_class = "1girl"
+positive = "1girl, long hair"
+negative = "1girl, short hair"
+multiplier = 1.0
+weight = 1.0
+```
+
+Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets.
+
+Each `[[targets]]` entry supports the following fields:
+
+| Field | Required | Default | Description |
+|-------|----------|---------|-------------|
+| `target_class` | Yes | - | The base class/subject prompt |
+| `positive` | No* | `""` | Prompt for the positive direction of the slider |
+| `negative` | No* | `""` | Prompt for the negative direction of the slider |
+| `multiplier` | No | `1.0` | LoRA strength multiplier |
+| `weight` | No | `1.0` | Loss weight |
+
+\* At least one of `positive` or `negative` must be provided.
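+
+Using the example above (`positive = "1girl, long hair"`, `negative = "1girl, short hair"`), the four internally generated pairs look roughly like this (pseudocode; see `_expand_slider_target` in `library/leco_train_util.py`):
+
+```
+# (positive_prompt,    unconditional_prompt, action,    multiplier)
+("1girl, short hair",  "1girl, long hair",   "erase",    +1.0)
+("1girl, long hair",   "1girl, short hair",  "enhance",  +1.0)
+("1girl, long hair",   "1girl, short hair",  "erase",    -1.0)
+("1girl, short hair",  "1girl, long hair",   "enhance",  -1.0)
+```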
+
+
+日本語
+
+`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
+
+トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
+
+各 `[[targets]]` エントリのフィールド:
+
+| フィールド | 必須 | デフォルト | 説明 |
+|-----------|------|-----------|------|
+| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
+| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
+| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
+| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
+| `weight` | いいえ | `1.0` | loss 重み |
+
+\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。
+
+
+### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト
+
+You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization.
+
+```toml
+guidance_scale = 1.5
+resolution = 1024
+neutrals = ["", "photo of a person", "cinematic portrait"]
+
+[[targets]]
+target_class = "person"
+positive = "smiling person"
+negative = "expressionless person"
+```
+
+You can also load neutral prompts from a text file (one prompt per line):
+
+```toml
+neutral_prompt_file = "neutrals.txt"
+
+[[targets]]
+target_class = ""
+positive = "high detail"
+negative = "low detail"
+```
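+
+The referenced text file simply lists one neutral prompt per line; empty lines are ignored. For example (contents are illustrative):
+
+```
+photo of a person
+cinematic portrait
+a landscape painting
+```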
+
+
+日本語
+
+スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
+
+ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
+
+
+### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換
+
+If you have an existing ai-toolkit style YAML config, convert it to TOML as follows:
+
+
+日本語
+既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
+
+
+**YAML:**
+```yaml
+targets:
+ - target_class: ""
+ positive: "high detail"
+ negative: "low detail"
+ multiplier: 1.0
+guidance_scale: 1.0
+resolution: 512
+```
+
+**TOML:**
+```toml
+guidance_scale = 1.0
+resolution = 512
+
+[[targets]]
+target_class = ""
+positive = "high detail"
+negative = "low detail"
+multiplier = 1.0
+```
+
+Key syntax differences:
+
+- Use `=` instead of `:` for key-value pairs
+- Use `[[targets]]` header instead of `targets:` with `- ` list items
+- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`)
+
+
+日本語
+
+主な構文の違い:
+
+- キーと値の区切りに `:` ではなく `=` を使用
+- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
+- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
+
+
+## 3. Running the Training / 学習の実行
+
+Training is started by executing the script from the terminal. Below are basic command-line examples.
+
+Each command must actually be entered on a single line; it is split across multiple lines here for readability. To continue a command across lines, append `\` at the end of each line on Linux/macOS, or `^` on Windows.
+
+
+日本語
+学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
+
+実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
+
+
+### SD 1.x / 2.x
+
+```bash
+accelerate launch --mixed_precision bf16 train_leco.py
+ --pretrained_model_name_or_path="model.safetensors"
+ --prompts_file="prompts.toml"
+ --output_dir="output"
+ --output_name="my_leco"
+ --network_dim=8
+ --network_alpha=4
+ --learning_rate=1e-4
+ --optimizer_type="AdamW8bit"
+ --max_train_steps=500
+ --max_denoising_steps=40
+ --mixed_precision=bf16
+ --sdpa
+ --gradient_checkpointing
+ --save_every_n_steps=100
+```
+
+### SDXL
+
+```bash
+accelerate launch --mixed_precision bf16 sdxl_train_leco.py
+ --pretrained_model_name_or_path="sdxl_model.safetensors"
+ --prompts_file="slider.toml"
+ --output_dir="output"
+ --output_name="my_sdxl_slider"
+ --network_dim=8
+ --network_alpha=4
+ --learning_rate=1e-4
+ --optimizer_type="AdamW8bit"
+ --max_train_steps=1000
+ --max_denoising_steps=40
+ --mixed_precision=bf16
+ --sdpa
+ --gradient_checkpointing
+ --save_every_n_steps=200
+```
+
+## 4. Command-Line Arguments / コマンドライン引数
+
+### 4.1. LECO-Specific Arguments / LECO 固有の引数
+
+These arguments are unique to LECO and not found in standard LoRA training scripts.
+
+
+日本語
+以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
+
+
+* `--prompts_file="prompts.toml"` **[Required]**
+ * Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format.
+
+* `--max_denoising_steps=40`
+  * Upper bound on the partial denoising performed per training iteration. At each iteration, a random number of denoising steps (between 1 and this value) is run to produce the intermediate latents. Default: `40`.
+
+* `--leco_denoise_guidance_scale=3.0`
+ * Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`.
+
+
+日本語
+
+* `--prompts_file="prompts.toml"` **[必須]**
+ * LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
+
+* `--max_denoising_steps=40`
+ * 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。
+
+* `--leco_denoise_guidance_scale=3.0`
+ * 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。
+
+
+#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い
+
+There are two separate guidance scale parameters that control different aspects of LECO training:
+
+1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal.
+
+2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target.
+
+If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`).
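+
+A pseudocode sketch of where each scale enters:
+
+```
+# (1) partial denoising pass: standard classifier-free guidance,
+#     controlled by --leco_denoise_guidance_scale
+noise_pred = uncond_pred + leco_denoise_guidance_scale * (cond_pred - uncond_pred)
+
+# (2) training-target construction, controlled by guidance_scale in the TOML
+target = neutral_pred - guidance_scale * (positive_pred - unconditional_pred)  # erase
+target = neutral_pred + guidance_scale * (positive_pred - unconditional_pred)  # enhance
+```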
+
+
+日本語
+
+LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
+
+1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
+
+2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
+
+学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
+
+
+### 4.2. Model Arguments / モデル引数
+
+* `--pretrained_model_name_or_path="model.safetensors"` **[Required]**
+ * Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID).
+
+* `--v2` (SD 1.x/2.x only)
+ * Specify when using a Stable Diffusion v2.x model.
+
+* `--v_parameterization` (SD 1.x/2.x only)
+ * Specify when using a v-prediction model (e.g., SD 2.x 768px models).
+
+
+日本語
+
+* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
+ * ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。
+
+* `--v2`(SD 1.x/2.x のみ)
+ * Stable Diffusion v2.x モデルを使用する場合に指定します。
+
+* `--v_parameterization`(SD 1.x/2.x のみ)
+ * v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
+
+
+### 4.3. LoRA Network Arguments / LoRA ネットワーク引数
+
+* `--network_module=networks.lora`
+ * Network module to train. Default: `networks.lora`.
+
+* `--network_dim=8`
+ * LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`.
+
+* `--network_alpha=4`
+ * LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`.
+
+* `--network_dropout=0.1`
+ * Dropout rate for LoRA layers. Optional.
+
+* `--network_args "key=value" ...`
+ * Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA.
+
+* `--network_weights="path/to/weights.safetensors"`
+ * Load pretrained LoRA weights to continue training.
+
+* `--dim_from_weights`
+ * Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`.
+
+
+日本語
+
+* `--network_module=networks.lora`
+ * 学習するネットワークモジュール。デフォルト: `networks.lora`。
+
+* `--network_dim=8`
+ * LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。
+
+* `--network_alpha=4`
+ * 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。
+
+* `--network_dropout=0.1`
+ * LoRA レイヤーのドロップアウト率。省略可。
+
+* `--network_args "key=value" ...`
+ * ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
+
+* `--network_weights="path/to/weights.safetensors"`
+ * 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
+
+* `--dim_from_weights`
+ * `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
+
+
+### 4.4. Training Parameters / 学習パラメータ
+
+* `--max_train_steps=500`
+ * Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`.
+ * Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only).
+
+* `--learning_rate=1e-4`
+ * Learning rate. Typical range for LECO: `1e-4` to `1e-3`.
+
+* `--unet_lr=1e-4`
+ * Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used.
+
+* `--optimizer_type="AdamW8bit"`
+ * Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc.
+
+* `--lr_scheduler="constant"`
+ * Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc.
+
+* `--lr_warmup_steps=0`
+ * Number of warmup steps for the learning rate scheduler.
+
+* `--gradient_accumulation_steps=1`
+ * Number of steps to accumulate gradients before updating. Effectively multiplies the batch size.
+
+* `--max_grad_norm=1.0`
+ * Maximum gradient norm for gradient clipping. Set to `0` to disable.
+
+* `--min_snr_gamma=5.0`
+ * Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional.
+
+
+日本語
+
+* `--max_train_steps=500`
+ * 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。
+ * 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。
+
+* `--learning_rate=1e-4`
+ * 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。
+
+* `--unet_lr=1e-4`
+ * U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。
+
+* `--optimizer_type="AdamW8bit"`
+ * オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。
+
+* `--lr_scheduler="constant"`
+ * 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。
+
+* `--lr_warmup_steps=0`
+ * 学習率スケジューラのウォームアップステップ数。
+
+* `--gradient_accumulation_steps=1`
+ * 勾配を累積するステップ数。実質的にバッチサイズを増加させます。
+
+* `--max_grad_norm=1.0`
+ * 勾配クリッピングの最大勾配ノルム。`0` で無効化。
+
+* `--min_snr_gamma=5.0`
+ * Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。
+
+
+### 4.5. Output and Save Arguments / 出力・保存引数
+
+* `--output_dir="output"` **[Required]**
+ * Directory for saving trained LoRA models and logs.
+
+* `--output_name="my_leco"` **[Required]**
+ * Base filename for the trained LoRA (without extension).
+
+* `--save_model_as="safetensors"`
+ * Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`.
+
+* `--save_every_n_steps=100`
+ * Save an intermediate checkpoint every N steps. If not specified, only the final model is saved.
+ * Note: `--save_every_n_epochs` is **not supported** for LECO.
+
+* `--save_precision="fp16"`
+ * Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used.
+
+* `--no_metadata`
+ * Do not write metadata into the saved model file.
+
+* `--training_comment="my comment"`
+ * A comment string stored in the model metadata.
+
+
+日本語
+
+* `--output_dir="output"` **[必須]**
+ * 学習済み LoRA モデルとログの保存先ディレクトリ。
+
+* `--output_name="my_leco"` **[必須]**
+ * 学習済み LoRA のベースファイル名(拡張子なし)。
+
+* `--save_model_as="safetensors"`
+ * モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。
+
+* `--save_every_n_steps=100`
+ * N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。
+ * 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。
+
+* `--save_precision="fp16"`
+ * モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。
+
+* `--no_metadata`
+ * 保存するモデルファイルにメタデータを書き込みません。
+
+* `--training_comment="my comment"`
+ * モデルのメタデータに保存されるコメント文字列。
+
+
+### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数
+
+* `--mixed_precision="bf16"`
+ * Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended.
+
+* `--full_fp16`
+ * Train entirely in fp16 precision including gradients.
+
+* `--full_bf16`
+ * Train entirely in bf16 precision including gradients.
+
+* `--gradient_checkpointing`
+ * Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions.
+
+* `--sdpa`
+ * Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended.
+
+* `--xformers`
+ * Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`.
+
+* `--mem_eff_attn`
+ * Use memory-efficient attention implementation. Another alternative to `--sdpa`.
+
+
+日本語
+
+* `--mixed_precision="bf16"`
+ * 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。
+
+* `--full_fp16`
+ * 勾配を含め全体を fp16 精度で学習します。
+
+* `--full_bf16`
+ * 勾配を含め全体を bf16 精度で学習します。
+
+* `--gradient_checkpointing`
+ * gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。
+
+* `--sdpa`
+ * Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。
+
+* `--xformers`
+ * xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。
+
+* `--mem_eff_attn`
+ * メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。
+
+
+### 4.7. Other Useful Arguments / その他の便利な引数
+
+* `--seed=42`
+ * Random seed for reproducibility. If not specified, a random seed is automatically generated.
+
+* `--noise_offset=0.05`
+ * Enable noise offset. Small values like `0.02` to `0.1` can help with training stability.
+
+* `--zero_terminal_snr`
+ * Fix noise scheduler betas to enforce zero terminal SNR.
+
+* `--clip_skip=2` (SD 1.x/2.x only)
+ * Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`.
+
+* `--logging_dir="logs"`
+ * Directory for TensorBoard logs. Enables logging when specified.
+
+* `--log_with="tensorboard"`
+ * Logging tool. Options: `tensorboard`, `wandb`, `all`.
+
+
+日本語
+
+* `--seed=42`
+ * 再現性のための乱数シード。指定しない場合は自動生成されます。
+
+* `--noise_offset=0.05`
+ * ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。
+
+* `--zero_terminal_snr`
+ * noise scheduler の betas を修正してゼロ終端 SNR を強制します。
+
+* `--clip_skip=2`(SD 1.x/2.x のみ)
+ * text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。
+
+* `--logging_dir="logs"`
+ * TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。
+
+* `--log_with="tensorboard"`
+ * ログツール。`tensorboard`、`wandb`、`all` から選択。
+
+
+## 5. Tips / ヒント
+
+### Tuning the Effect Strength / 効果の強さの調整
+
+If the trained LoRA has a weak or unnoticeable effect:
+
+1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect.
+2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`).
+3. **Increase `--max_denoising_steps`** for more refined intermediate latents.
+4. **Increase `--max_train_steps`** to train longer.
+5. **Apply the LoRA with a higher weight** at inference time.
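+
+Applying tips 1 and 2, a strengthened slider configuration might look like this (the values are illustrative starting points, not tuned results):
+
+```toml
+guidance_scale = 3.0
+resolution = 1024
+
+[[targets]]
+target_class = "1girl"
+positive = "1girl, long hair"
+negative = "1girl, short hair"
+multiplier = 2.0
+```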
+
+
+日本語
+
+学習した LoRA の効果が弱い、または認識できない場合:
+
+1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。
+2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。
+3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。
+4. **`--max_train_steps` を増やして**、より長く学習する。
+5. **推論時に LoRA のウェイトを大きくして**適用する。
+
+
+### Recommended Starting Settings / 推奨の開始設定
+
+| Parameter | SD 1.x/2.x | SDXL |
+|-----------|-------------|------|
+| `--network_dim` | `4`-`8` | `8`-`16` |
+| `--learning_rate` | `1e-4` | `1e-4` |
+| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
+| `resolution` (in TOML) | `512` | `1024` |
+| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` |
+| `batch_size` (in TOML) | `1`-`4` | `1`-`4` |
+
+
+日本語
+
+| パラメータ | SD 1.x/2.x | SDXL |
+|-----------|-------------|------|
+| `--network_dim` | `4`-`8` | `8`-`16` |
+| `--learning_rate` | `1e-4` | `1e-4` |
+| `--max_train_steps` | `300`-`1000` | `500`-`2000` |
+| `resolution`(TOML内) | `512` | `1024` |
+| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` |
+| `batch_size`(TOML内) | `1`-`4` | `1`-`4` |
+
+
+### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)
+
+For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:
+
+```toml
+resolution = 1024
+dynamic_resolution = true
+dynamic_crops = true
+
+[[targets]]
+target_class = ""
+positive = "high detail"
+negative = "low detail"
+```
+
+- `dynamic_resolution`: Randomly varies the training resolution around the base value using aspect ratio buckets.
+- `dynamic_crops`: Randomizes crop positions in the SDXL size conditioning embeddings.
+
+These options can improve the LoRA's generalization across different aspect ratios.
+
+
+日本語
+
+SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。
+
+- `dynamic_resolution`: アスペクト比バケット(bucket)を使用して、ベース値の周囲で学習解像度をランダムに変化させます。
+- `dynamic_crops`: SDXL のサイズ条件付け埋め込みでクロップ位置をランダム化します。
+
+これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
+
+
+## 6. Using the Trained Model / 学習済みモデルの利用
+
+The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui and ComfyUI.
+
+For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.
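+
+For example, with AUTOMATIC1111/stable-diffusion-webui extra-network prompt syntax (assuming the saved file is named `my_sdxl_slider.safetensors`):
+
+```
+# positive direction
+1girl, solo <lora:my_sdxl_slider:1.0>
+
+# negative direction
+1girl, solo <lora:my_sdxl_slider:-1.0>
+```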
+
+
+日本語
+
+学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。
+
+スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。
+
diff --git a/library/leco_train_util.py b/library/leco_train_util.py
new file mode 100644
index 00000000..5e95c163
--- /dev/null
+++ b/library/leco_train_util.py
@@ -0,0 +1,522 @@
+import argparse
+import json
+import math
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import torch
+import toml
+from torch.utils.checkpoint import checkpoint
+
+from library import train_util
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def build_network_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
+    # Parse --network_args "key=value" pairs; fall back to --network_dropout
+    # for the dropout rate when it is not overridden explicitly.
+    kwargs: Dict[str, Any] = {}
+    if args.network_args:
+        for net_arg in args.network_args:
+            key, value = net_arg.split("=", 1)
+            kwargs[key] = value
+    if "dropout" not in kwargs and args.network_dropout is not None:
+        kwargs["dropout"] = args.network_dropout
+    return kwargs
+
+
+def get_save_extension(args: argparse.Namespace) -> str:
+ if args.save_model_as == "ckpt":
+ return ".ckpt"
+ if args.save_model_as == "pt":
+ return ".pt"
+ return ".safetensors"
+
+
+def save_weights(
+ accelerator,
+ network,
+ args: argparse.Namespace,
+ save_dtype,
+ prompt_settings,
+ global_step: int,
+ last: bool = False,
+ extra_metadata: Optional[Dict[str, str]] = None,
+) -> None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ ext = get_save_extension(args)
+ ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step)
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
+
+ metadata = None
+ if not args.no_metadata:
+ metadata = {
+ "ss_network_module": args.network_module,
+ "ss_network_dim": str(args.network_dim),
+ "ss_network_alpha": str(args.network_alpha),
+ "ss_leco_prompt_count": str(len(prompt_settings)),
+ "ss_leco_prompts_file": os.path.basename(args.prompts_file),
+ }
+ if extra_metadata:
+ metadata.update(extra_metadata)
+ if args.training_comment:
+ metadata["ss_training_comment"] = args.training_comment
+ metadata["ss_leco_preview"] = json.dumps(
+ [
+ {
+ "target": p.target,
+ "positive": p.positive,
+ "unconditional": p.unconditional,
+ "neutral": p.neutral,
+ "action": p.action,
+ "multiplier": p.multiplier,
+ "weight": p.weight,
+ }
+ for p in prompt_settings[:16]
+ ],
+ ensure_ascii=False,
+ )
+
+ unwrapped = accelerator.unwrap_model(network)
+ unwrapped.save_weights(ckpt_file, save_dtype, metadata)
+ logger.info(f"saved model to: {ckpt_file}")
+
+
+ResolutionValue = Union[int, Tuple[int, int]]
+
+
+@dataclass
+class PromptEmbedsXL:
+ text_embeds: torch.Tensor
+ pooled_embeds: torch.Tensor
+
+
+class PromptEmbedsCache:
+ def __init__(self):
+ self.prompts: dict[str, Any] = {}
+
+ def __setitem__(self, name: str, value: Any) -> None:
+ self.prompts[name] = value
+
+ def __getitem__(self, name: str) -> Any:
+ return self.prompts[name]
+
+
+@dataclass
+class PromptSettings:
+ target: str
+ positive: Optional[str] = None
+ unconditional: str = ""
+ neutral: Optional[str] = None
+ action: str = "erase"
+ guidance_scale: float = 1.0
+ resolution: ResolutionValue = 512
+ dynamic_resolution: bool = False
+ batch_size: int = 1
+ dynamic_crops: bool = False
+ multiplier: float = 1.0
+ weight: float = 1.0
+
+ def __post_init__(self):
+ if self.positive is None:
+ self.positive = self.target
+ if self.neutral is None:
+ self.neutral = self.unconditional
+ if self.action not in ("erase", "enhance"):
+ raise ValueError(f"Invalid action: {self.action}")
+
+ self.guidance_scale = float(self.guidance_scale)
+ self.batch_size = int(self.batch_size)
+ self.multiplier = float(self.multiplier)
+ self.weight = float(self.weight)
+ self.dynamic_resolution = bool(self.dynamic_resolution)
+ self.dynamic_crops = bool(self.dynamic_crops)
+ self.resolution = normalize_resolution(self.resolution)
+
+ def get_resolution(self) -> Tuple[int, int]:
+ if isinstance(self.resolution, tuple):
+ return self.resolution
+ return (self.resolution, self.resolution)
+
+ def build_target(self, positive_latents, neutral_latents, unconditional_latents):
+ offset = self.guidance_scale * (positive_latents - unconditional_latents)
+ if self.action == "erase":
+ return neutral_latents - offset
+ return neutral_latents + offset
+
+
+def normalize_resolution(value: Any) -> ResolutionValue:
+ if isinstance(value, tuple):
+ if len(value) != 2:
+ raise ValueError(f"resolution tuple must have 2 items: {value}")
+ return (int(value[0]), int(value[1]))
+ if isinstance(value, list):
+ if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
+ return (int(value[0]), int(value[1]))
+ raise ValueError(f"resolution list must have 2 numeric items: {value}")
+ return int(value)
+
+
+def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
+ with open(path, "r", encoding="utf-8") as f:
+ return [line.strip() for line in f.readlines() if line.strip()]
+
+
+def _recognized_prompt_keys() -> set[str]:
+ return {
+ "target",
+ "positive",
+ "unconditional",
+ "neutral",
+ "action",
+ "guidance_scale",
+ "resolution",
+ "dynamic_resolution",
+ "batch_size",
+ "dynamic_crops",
+ "multiplier",
+ "weight",
+ }
+
+
+def _recognized_slider_keys() -> set[str]:
+ return {
+ "target_class",
+ "positive",
+ "negative",
+ "neutral",
+ "guidance_scale",
+ "resolution",
+ "resolutions",
+ "dynamic_resolution",
+ "batch_size",
+ "dynamic_crops",
+ "multiplier",
+ "weight",
+ }
+
+
+def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
+ merged = {k: v for k, v in defaults.items() if k in known_keys}
+ merged.update(item)
+ return merged
+
+
+def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
+ if value is None:
+ return [512]
+ if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
+ return [normalize_resolution(v) for v in value]
+ return [normalize_resolution(value)]
+
+
+def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
+ target_class = str(target.get("target_class", ""))
+ positive = str(target.get("positive", "") or "")
+ negative = str(target.get("negative", "") or "")
+ multiplier = target.get("multiplier", 1.0)
+ resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))
+
+ if not positive.strip() and not negative.strip():
+ raise ValueError("slider target requires either positive or negative prompt")
+
+ base = dict(
+ target=target_class,
+ neutral=neutral,
+ guidance_scale=target.get("guidance_scale", 1.0),
+ dynamic_resolution=target.get("dynamic_resolution", False),
+ batch_size=target.get("batch_size", 1),
+ dynamic_crops=target.get("dynamic_crops", False),
+ weight=target.get("weight", 1.0),
+ )
+
+ # Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs.
+ # With both positive and negative: 4 pairs; with only one: 2 pairs.
+ pairs: list[tuple[str, str, str, float]] = []
+ if positive.strip() and negative.strip():
+ pairs = [
+ (negative, positive, "erase", multiplier),
+ (positive, negative, "enhance", multiplier),
+ (positive, negative, "erase", -multiplier),
+ (negative, positive, "enhance", -multiplier),
+ ]
+ elif negative.strip():
+ pairs = [
+ (negative, "", "erase", multiplier),
+ (negative, "", "enhance", -multiplier),
+ ]
+ else:
+ pairs = [
+ (positive, "", "enhance", multiplier),
+ (positive, "", "erase", -multiplier),
+ ]
+
+ prompt_settings: List[PromptSettings] = []
+ for resolution in resolutions:
+ for pos, uncond, action, mult in pairs:
+ prompt_settings.append(
+ PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult)
+ )
+
+ return prompt_settings
+
+
+def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
+ path = Path(path)
+ with open(path, "r", encoding="utf-8") as f:
+ data = toml.load(f)
+
+ if not data:
+ raise ValueError("prompt file is empty")
+
+ default_prompt_values = {
+ "guidance_scale": 1.0,
+ "resolution": 512,
+ "dynamic_resolution": False,
+ "batch_size": 1,
+ "dynamic_crops": False,
+ "multiplier": 1.0,
+ "weight": 1.0,
+ }
+
+ prompt_settings: List[PromptSettings] = []
+
+ def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
+ merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
+ prompt_settings.append(PromptSettings(**merged))
+
+ def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
+ merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
+ if not neutral_values:
+ neutral_values = [str(merged.get("neutral", "") or "")]
+ for neutral in neutral_values:
+ prompt_settings.extend(_expand_slider_target(merged, neutral))
+
+ if "prompts" in data:
+ defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
+ for item in data["prompts"]:
+            if "target_class" in item:
+                # pass an empty list when the item has no "neutral" so that append_slider_item
+                # can fall back to a top-level default neutral
+                append_slider_item(item, defaults, [str(item["neutral"] or "")] if "neutral" in item else [])
+ else:
+ append_prompt_item(item, defaults)
+ else:
+ slider_config = data.get("slider", data)
+ targets = slider_config.get("targets")
+ if targets is None:
+ if "target_class" in slider_config:
+ targets = [slider_config]
+ elif "target" in slider_config:
+ targets = [slider_config]
+ else:
+ raise ValueError("prompt file does not contain prompts or slider targets")
+ if len(targets) == 0:
+ raise ValueError("prompt file contains an empty targets list")
+
+ if "target" in targets[0]:
+ defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}}
+ for item in targets:
+ append_prompt_item(item, defaults)
+ else:
+ defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}}
+ neutral_values: List[str] = []
+ if "neutrals" in slider_config:
+ neutral_values.extend(str(v) for v in slider_config["neutrals"])
+ if "neutral_prompt_file" in slider_config:
+ neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
+ if "prompt_file" in slider_config:
+ neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
+ if not neutral_values:
+ neutral_values = [str(slider_config.get("neutral", "") or "")]
+
+ for item in targets:
+ item_neutrals = neutral_values
+ if "neutrals" in item:
+ item_neutrals = [str(v) for v in item["neutrals"]]
+ elif "neutral_prompt_file" in item:
+ item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
+ elif "prompt_file" in item:
+ item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
+ elif "neutral" in item:
+ item_neutrals = [str(item["neutral"] or "")]
+
+ append_slider_item(item, defaults, item_neutrals)
+
+ if not prompt_settings:
+ raise ValueError("no prompt settings found")
+
+ return prompt_settings
+
+
+def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
+ tokens = tokenize_strategy.tokenize(prompt)
+ return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]
+
+
+def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
+ tokens = tokenize_strategy.tokenize(prompt)
+ hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
+ return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)
+
+
+def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
+ if noise_offset is None:
+ return latents
+ noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu")
+ noise = noise.to(dtype=latents.dtype, device=latents.device)
+ return latents + noise_offset * noise
+
+
+def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
+ noise = torch.randn(
+ (batch_size, 4, height // 8, width // 8),
+ device="cpu",
+ ).repeat(n_prompts, 1, 1, 1)
+ return noise * scheduler.init_noise_sigma
+
+
+def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
+ return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
+
+
+def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
+ text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
+ pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
+ return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
+
+
+def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
+ """Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch."""
+ return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0)
+
+
+def _run_with_checkpoint(function, *args):
+ if torch.is_grad_enabled():
+ return checkpoint(function, *args, use_reentrant=False)
+ return function(*args)
+
+
+def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
+
+ def run_unet(model_input, encoder_hidden_states):
+ return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+ noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+
+def diffusion(
+ unet,
+ scheduler,
+ latents: torch.Tensor,
+ text_embeddings: torch.Tensor,
+ total_timesteps: int,
+ start_timesteps: int = 0,
+ guidance_scale: float = 3.0,
+):
+ for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
+ noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
+ return latents
+
+
+def get_add_time_ids(
+ height: int,
+ width: int,
+ dynamic_crops: bool = False,
+ dtype: torch.dtype = torch.float32,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ if dynamic_crops:
+ random_scale = torch.rand(1).item() * 2 + 1
+ original_size = (int(height * random_scale), int(width * random_scale))
+ crops_coords_top_left = (
+ torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
+ torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
+ )
+ target_size = (height, width)
+ else:
+ original_size = (height, width)
+ crops_coords_top_left = (0, 0)
+ target_size = (height, width)
+
+ add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
+ if device is not None:
+ add_time_ids = add_time_ids.to(device)
+ return add_time_ids
+
+
+def predict_noise_xl(
+ unet,
+ scheduler,
+ timestep,
+ latents: torch.Tensor,
+ prompt_embeds: PromptEmbedsXL,
+ add_time_ids: torch.Tensor,
+ guidance_scale: float = 1.0,
+):
+ latent_model_input = torch.cat([latents] * 2)
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
+
+ orig_size = add_time_ids[:, :2]
+ crop_size = add_time_ids[:, 2:4]
+ target_size = add_time_ids[:, 4:6]
+ from library import sdxl_train_util
+
+ size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
+ vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
+
+ def run_unet(model_input, text_embeds, vector_embeds):
+ return unet(model_input, timestep, text_embeds, vector_embeds)
+
+ noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding)
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+
+def diffusion_xl(
+ unet,
+ scheduler,
+ latents: torch.Tensor,
+ prompt_embeds: PromptEmbedsXL,
+ add_time_ids: torch.Tensor,
+ total_timesteps: int,
+ start_timesteps: int = 0,
+ guidance_scale: float = 3.0,
+):
+ for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
+ noise_pred = predict_noise_xl(
+ unet,
+ scheduler,
+ timestep,
+ latents,
+ prompt_embeds=prompt_embeds,
+ add_time_ids=add_time_ids,
+ guidance_scale=guidance_scale,
+ )
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
+ return latents
+
+
+def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
+ max_resolution = bucket_resolution
+ min_resolution = bucket_resolution // 2
+ step = 64
+ min_step = min_resolution // step
+ max_step = max_resolution // step
+ height = torch.randint(min_step, max_step + 1, (1,)).item() * step
+ width = torch.randint(min_step, max_step + 1, (1,)).item() * step
+ return height, width
+
+
+def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
+ height, width = prompt.get_resolution()
+ if prompt.dynamic_resolution and height == width:
+ return get_random_resolution_in_bucket(height)
+ return height, width
diff --git a/library/train_util.py b/library/train_util.py
index 672aa597..83d04f5e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all(
[
not (
- subset.caption_dropout_rate > 0 and not cache_supports_dropout
+ subset.caption_dropout_rate > 0
+ and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset):
filtered_img_paths.append(img_path)
filtered_sizes.append(size)
if len(filtered_img_paths) < len(img_paths):
- logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
+ logger.info(
+ f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
+ )
img_paths = filtered_img_paths
sizes = filtered_sizes
@@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset):
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
if len(extra_imgs) > 0:
- logger.warning(
- f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
- )
+ logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
self.conditioning_image_transforms = IMAGE_TRANSFORMS
diff --git a/sdxl_train_leco.py b/sdxl_train_leco.py
new file mode 100644
index 00000000..ff5550f9
--- /dev/null
+++ b/sdxl_train_leco.py
@@ -0,0 +1,342 @@
+import argparse
+import importlib
+import random
+
+import torch
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler
+from tqdm import tqdm
+
+from library.device_utils import init_ipex, clean_memory_on_device
+
+init_ipex()
+
+from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util
+from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
+from library.leco_train_util import (
+ PromptEmbedsCache,
+ apply_noise_offset,
+ batch_add_time_ids,
+ build_network_kwargs,
+ concat_embeddings_xl,
+ diffusion_xl,
+ encode_prompt_sdxl,
+ get_add_time_ids,
+ get_initial_latents,
+ get_random_resolution,
+ load_prompt_settings,
+ predict_noise_xl,
+ save_weights,
+)
+from library.utils import add_logging_arguments, setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(parser)
+ train_util.add_optimizer_arguments(parser)
+ train_util.add_training_arguments(parser, support_dreambooth=False)
+ custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
+ sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
+ add_logging_arguments(parser)
+
+ parser.add_argument(
+ "--save_model_as",
+ type=str,
+ default="safetensors",
+ choices=[None, "ckpt", "pt", "safetensors"],
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
+
+ parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
+ parser.add_argument(
+ "--max_denoising_steps",
+ type=int,
+ default=40,
+ help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
+ )
+ parser.add_argument(
+ "--leco_denoise_guidance_scale",
+ type=float,
+ default=3.0,
+ help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
+ )
+
+ parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
+ parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
+ parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
+ parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
+ parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
+ parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
+ parser.add_argument(
+ "--network_train_text_encoder_only",
+ action="store_true",
+ help="unsupported for LECO; kept for compatibility / LECOでは未対応",
+ )
+ parser.add_argument(
+ "--network_train_unet_only",
+ action="store_true",
+ help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
+ )
+ parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
+ parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
+
+ # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
+ parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
+ parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
+ parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
+
+ return parser
+
+
+def main():
+ parser = setup_parser()
+ args = parser.parse_args()
+ args = train_util.read_config_from_file(args, parser)
+ train_util.verify_training_args(args)
+ sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
+
+ if args.output_dir is None:
+ raise ValueError("--output_dir is required")
+ if args.network_train_text_encoder_only:
+ raise ValueError("LECO does not support text encoder LoRA training")
+
+ if args.seed is None:
+ args.seed = random.randint(0, 2**32 - 1)
+ set_seed(args.seed)
+
+ accelerator = train_util.prepare_accelerator(args)
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
+
+ prompt_settings = load_prompt_settings(args.prompts_file)
+ logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
+
+ _, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model(
+ args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
+ )
+ del vae
+ text_encoders = [text_encoder1, text_encoder2]
+
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+ unet.requires_grad_(False)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ unet.train()
+
+ tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
+ text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
+
+ for text_encoder in text_encoders:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+
+ prompt_cache = PromptEmbedsCache()
+ unique_prompts = sorted(
+ {
+ prompt
+ for setting in prompt_settings
+ for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
+ }
+ )
+ with torch.no_grad():
+ for prompt in unique_prompts:
+ prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt)
+
+ for text_encoder in text_encoders:
+ text_encoder.to("cpu", dtype=torch.float32)
+ clean_memory_on_device(accelerator.device)
+
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000,
+ clip_sample=False,
+ )
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+ if args.zero_terminal_snr:
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+
+ network_module = importlib.import_module(args.network_module)
+ net_kwargs = build_network_kwargs(args)
+ if args.dim_from_weights:
+ if args.network_weights is None:
+ raise ValueError("--dim_from_weights requires --network_weights")
+ network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs)
+ else:
+ network = network_module.create_network(
+ 1.0,
+ args.network_dim,
+ args.network_alpha,
+ None,
+ text_encoders,
+ unet,
+ neuron_dropout=args.network_dropout,
+ **net_kwargs,
+ )
+
+ network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True)
+ network.set_multiplier(0.0)
+
+ if args.network_weights is not None:
+ info = network.load_weights(args.network_weights)
+ logger.info(f"loaded network weights from {args.network_weights}: {info}")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ network.enable_gradient_checkpointing()
+
+ unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
+ trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+
+ network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet)
+
+ if args.full_fp16:
+ train_util.patch_accelerator_for_fp16_training(accelerator)
+
+ optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
+ optimizer_train_fn()
+ train_util.init_trackers(accelerator, args, "sdxl_leco_train")
+
+ progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
+ global_step = 0
+
+ while global_step < args.max_train_steps:
+ with accelerator.accumulate(network):
+ optimizer.zero_grad(set_to_none=True)
+
+ setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
+ noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
+
+ timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
+ height, width = get_random_resolution(setting)
+
+ latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
+ accelerator.device, dtype=weight_dtype
+ )
+ latents = apply_noise_offset(latents, args.noise_offset)
+ add_time_ids = get_add_time_ids(
+ height,
+ width,
+ dynamic_crops=setting.dynamic_crops,
+ dtype=weight_dtype,
+ device=accelerator.device,
+ )
+ batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size)
+
+            network_multiplier = accelerator.unwrap_model(network)  # unwrapped network, used to toggle the LoRA multiplier
+ network_multiplier.set_multiplier(setting.multiplier)
+ with accelerator.autocast():
+ denoised_latents = diffusion_xl(
+ unet,
+ noise_scheduler,
+ latents,
+ concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+ add_time_ids=batched_time_ids,
+ total_timesteps=timesteps_to,
+ guidance_scale=args.leco_denoise_guidance_scale,
+ )
+
+ noise_scheduler.set_timesteps(1000, device=accelerator.device)
+            # map the partial-denoise progress onto the full 1000-step schedule
+            current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
+ current_timestep = noise_scheduler.timesteps[current_timestep_index]
+
+ network_multiplier.set_multiplier(0.0)
+ with torch.no_grad(), accelerator.autocast():
+ positive_latents = predict_noise_xl(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
+ add_time_ids=batched_time_ids,
+ guidance_scale=1.0,
+ )
+ neutral_latents = predict_noise_xl(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
+ add_time_ids=batched_time_ids,
+ guidance_scale=1.0,
+ )
+ unconditional_latents = predict_noise_xl(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
+ add_time_ids=batched_time_ids,
+ guidance_scale=1.0,
+ )
+
+ network_multiplier.set_multiplier(setting.multiplier)
+ with accelerator.autocast():
+ target_latents = predict_noise_xl(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+ add_time_ids=batched_time_ids,
+ guidance_scale=1.0,
+ )
+
+ target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
+ loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=(1, 2, 3))
+ if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
+                # apply_snr_weight looks up SNR by timestep value; the schedule is in descending
+                # order, so the index into noise_scheduler.timesteps is not the timestep itself
+                timesteps = torch.full((loss.shape[0],), int(current_timestep.item()), device=loss.device, dtype=torch.long)
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+ loss = loss.mean() * setting.weight
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+ accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+
+ if accelerator.sync_gradients:
+ global_step += 1
+ progress_bar.update(1)
+ network_multiplier = accelerator.unwrap_model(network)
+ network_multiplier.set_multiplier(0.0)
+
+ logs = {
+ "loss": loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ "guidance_scale": setting.guidance_scale,
+ "network_multiplier": setting.multiplier,
+ }
+ accelerator.log(logs, step=global_step)
+ progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
+
+            # save only on optimizer-step boundaries; otherwise gradient accumulation could
+            # trigger duplicate saves at the same global_step (including step 0)
+            if (
+                accelerator.sync_gradients
+                and args.save_every_n_steps
+                and global_step % args.save_every_n_steps == 0
+                and global_step < args.max_train_steps
+            ):
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
+ save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra)
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0}
+ save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/library/test_leco_train_util.py b/tests/library/test_leco_train_util.py
new file mode 100644
index 00000000..5e950f43
--- /dev/null
+++ b/tests/library/test_leco_train_util.py
@@ -0,0 +1,116 @@
+from pathlib import Path
+
+import torch
+
+from library.leco_train_util import load_prompt_settings
+
+
+def test_load_prompt_settings_with_original_format(tmp_path: Path):
+ prompt_file = tmp_path / "prompts.toml"
+ prompt_file.write_text(
+ """
+[[prompts]]
+target = "van gogh"
+guidance_scale = 1.5
+resolution = 512
+""".strip(),
+ encoding="utf-8",
+ )
+
+ prompts = load_prompt_settings(prompt_file)
+
+ assert len(prompts) == 1
+ assert prompts[0].target == "van gogh"
+ assert prompts[0].positive == "van gogh"
+ assert prompts[0].unconditional == ""
+ assert prompts[0].neutral == ""
+ assert prompts[0].action == "erase"
+ assert prompts[0].guidance_scale == 1.5
+
+
+def test_load_prompt_settings_with_slider_targets(tmp_path: Path):
+ prompt_file = tmp_path / "slider.toml"
+ prompt_file.write_text(
+ """
+guidance_scale = 2.0
+resolution = 768
+neutral = ""
+
+[[targets]]
+target_class = ""
+positive = "high detail"
+negative = "low detail"
+multiplier = 1.25
+weight = 0.5
+""".strip(),
+ encoding="utf-8",
+ )
+
+ prompts = load_prompt_settings(prompt_file)
+
+ assert len(prompts) == 4
+
+ first = prompts[0]
+ second = prompts[1]
+ third = prompts[2]
+ fourth = prompts[3]
+
+ assert first.target == ""
+ assert first.positive == "low detail"
+ assert first.unconditional == "high detail"
+ assert first.action == "erase"
+ assert first.multiplier == 1.25
+ assert first.weight == 0.5
+ assert first.get_resolution() == (768, 768)
+
+ assert second.positive == "high detail"
+ assert second.unconditional == "low detail"
+ assert second.action == "enhance"
+ assert second.multiplier == 1.25
+
+ assert third.action == "erase"
+ assert third.multiplier == -1.25
+
+ assert fourth.action == "enhance"
+ assert fourth.multiplier == -1.25
+
+
+def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids():
+ from library import sdxl_train_util
+ from library.leco_train_util import PromptEmbedsXL, predict_noise_xl
+
+ class DummyScheduler:
+ def scale_model_input(self, latent_model_input, timestep):
+ return latent_model_input
+
+ class DummyUNet:
+ def __call__(self, x, timesteps, context, y):
+ self.x = x
+ self.timesteps = timesteps
+ self.context = context
+ self.y = y
+ return torch.zeros_like(x)
+
+ latents = torch.randn(1, 4, 8, 8)
+ prompt_embeds = PromptEmbedsXL(
+ text_embeds=torch.randn(2, 77, 2048),
+ pooled_embeds=torch.randn(2, 1280),
+ )
+ add_time_ids = torch.tensor(
+ [
+ [1024, 1024, 0, 0, 1024, 1024],
+ [1024, 1024, 0, 0, 1024, 1024],
+ ],
+ dtype=prompt_embeds.pooled_embeds.dtype,
+ )
+
+ unet = DummyUNet()
+ noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids)
+
+ expected_size_embeddings = sdxl_train_util.get_size_embeddings(
+ add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device
+ ).to(prompt_embeds.pooled_embeds.dtype)
+
+ assert noise_pred.shape == latents.shape
+ assert unet.context is prompt_embeds.text_embeds
+ assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1))
diff --git a/tests/test_sdxl_train_leco.py b/tests/test_sdxl_train_leco.py
new file mode 100644
index 00000000..637aa28f
--- /dev/null
+++ b/tests/test_sdxl_train_leco.py
@@ -0,0 +1,16 @@
+import sdxl_train_leco
+from library import deepspeed_utils, sdxl_train_util, train_util
+
+
+def test_syntax():
+ assert sdxl_train_leco is not None
+
+
+def test_setup_parser_supports_shared_training_validation():
+    args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.toml"])
+
+ train_util.verify_training_args(args)
+ sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
+
+ assert args.min_snr_gamma is None
+ assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
diff --git a/tests/test_train_leco.py b/tests/test_train_leco.py
new file mode 100644
index 00000000..4a43d3d7
--- /dev/null
+++ b/tests/test_train_leco.py
@@ -0,0 +1,15 @@
+import train_leco
+from library import deepspeed_utils, train_util
+
+
+def test_syntax():
+ assert train_leco is not None
+
+
+def test_setup_parser_supports_shared_training_validation():
+    args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.toml"])
+
+ train_util.verify_training_args(args)
+
+ assert args.min_snr_gamma is None
+ assert deepspeed_utils.prepare_deepspeed_plugin(args) is None
diff --git a/train_leco.py b/train_leco.py
new file mode 100644
index 00000000..e5439e0f
--- /dev/null
+++ b/train_leco.py
@@ -0,0 +1,319 @@
+import argparse
+import importlib
+import random
+
+import torch
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler
+from tqdm import tqdm
+
+from library.device_utils import init_ipex, clean_memory_on_device
+
+init_ipex()
+
+from library import custom_train_functions, strategy_sd, train_util
+from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training
+from library.leco_train_util import (
+ PromptEmbedsCache,
+ apply_noise_offset,
+ build_network_kwargs,
+ concat_embeddings,
+ diffusion,
+ encode_prompt_sd,
+ get_initial_latents,
+ get_random_resolution,
+ get_save_extension,
+ load_prompt_settings,
+ predict_noise,
+ save_weights,
+)
+from library.utils import add_logging_arguments, setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def setup_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(parser)
+ train_util.add_optimizer_arguments(parser)
+ train_util.add_training_arguments(parser, support_dreambooth=False)
+ custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False)
+ add_logging_arguments(parser)
+
+ parser.add_argument(
+ "--save_model_as",
+ type=str,
+ default="safetensors",
+ choices=[None, "ckpt", "pt", "safetensors"],
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない")
+
+ parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml")
+ parser.add_argument(
+ "--max_denoising_steps",
+ type=int,
+ default=40,
+ help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数",
+ )
+ parser.add_argument(
+ "--leco_denoise_guidance_scale",
+ type=float,
+ default=3.0,
+ help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale",
+ )
+
+ parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network")
+ parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train")
+ parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank")
+ parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha")
+ parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout")
+ parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments")
+ parser.add_argument(
+ "--network_train_text_encoder_only",
+ action="store_true",
+ help="unsupported for LECO; kept for compatibility / LECOでは未対応",
+ )
+ parser.add_argument(
+ "--network_train_unet_only",
+ action="store_true",
+ help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習",
+ )
+ parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata")
+ parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights")
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
+
+ # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed)
+ parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS)
+ parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS)
+ parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS)
+
+ return parser
+
+
+def main():
+ parser = setup_parser()
+ args = parser.parse_args()
+ args = train_util.read_config_from_file(args, parser)
+ train_util.verify_training_args(args)
+
+ if args.output_dir is None:
+ raise ValueError("--output_dir is required")
+ if args.network_train_text_encoder_only:
+ raise ValueError("LECO does not support text encoder LoRA training")
+
+ if args.seed is None:
+ args.seed = random.randint(0, 2**32 - 1)
+ set_seed(args.seed)
+
+ accelerator = train_util.prepare_accelerator(args)
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
+
+ prompt_settings = load_prompt_settings(args.prompts_file)
+ logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}")
+
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
+ del vae
+
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+ unet.requires_grad_(False)
+ unet.to(accelerator.device, dtype=weight_dtype)
+ unet.train()
+
+ tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+ text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
+
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+
+ prompt_cache = PromptEmbedsCache()
+ unique_prompts = sorted(
+ {
+ prompt
+ for setting in prompt_settings
+ for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral)
+ }
+ )
+ with torch.no_grad():
+ for prompt in unique_prompts:
+ prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt)
+
+ text_encoder.to("cpu")
+ clean_memory_on_device(accelerator.device)
+
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ num_train_timesteps=1000,
+ clip_sample=False,
+ )
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+ if args.zero_terminal_snr:
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+
+ network_module = importlib.import_module(args.network_module)
+ net_kwargs = build_network_kwargs(args)
+ if args.dim_from_weights:
+ if args.network_weights is None:
+ raise ValueError("--dim_from_weights requires --network_weights")
+ network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs)
+ else:
+ network = network_module.create_network(
+ 1.0,
+ args.network_dim,
+ args.network_alpha,
+ None,
+ text_encoder,
+ unet,
+ neuron_dropout=args.network_dropout,
+ **net_kwargs,
+ )
+
+ network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True)
+ network.set_multiplier(0.0)
+
+ if args.network_weights is not None:
+ info = network.load_weights(args.network_weights)
+ logger.info(f"loaded network weights from {args.network_weights}: {info}")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ network.enable_gradient_checkpointing()
+
+ unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate
+ trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate)
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+
+ network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler)
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
+
+ if args.full_fp16:
+ train_util.patch_accelerator_for_fp16_training(accelerator)
+
+ optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args)
+ optimizer_train_fn()
+ train_util.init_trackers(accelerator, args, "leco_train")
+
+ progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps")
+ global_step = 0
+
+ while global_step < args.max_train_steps:
+ with accelerator.accumulate(network):
+ # pick a random prompt setting and a random partial-denoising depth for this step
+ setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()]
+ noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device)
+
+ timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item()
+ height, width = get_random_resolution(setting)
+
+ latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to(
+ accelerator.device, dtype=weight_dtype
+ )
+ latents = apply_noise_offset(latents, args.noise_offset)
+
+ unwrapped_network = accelerator.unwrap_model(network)
+ unwrapped_network.set_multiplier(setting.multiplier)
+ with accelerator.autocast():
+ denoised_latents = diffusion(
+ unet,
+ noise_scheduler,
+ latents,
+ concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+ total_timesteps=timesteps_to,
+ guidance_scale=args.leco_denoise_guidance_scale,
+ )
+
+ # map the partial-denoising step count onto the full 1000-step training schedule
+ noise_scheduler.set_timesteps(1000, device=accelerator.device)
+ current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps)
+ current_timestep = noise_scheduler.timesteps[current_timestep_index]
+
+ unwrapped_network.set_multiplier(0.0)
+ with torch.no_grad(), accelerator.autocast():
+ positive_latents = predict_noise(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size),
+ guidance_scale=1.0,
+ )
+ neutral_latents = predict_noise(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size),
+ guidance_scale=1.0,
+ )
+ unconditional_latents = predict_noise(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size),
+ guidance_scale=1.0,
+ )
+
+ unwrapped_network.set_multiplier(setting.multiplier)
+ with accelerator.autocast():
+ target_latents = predict_noise(
+ unet,
+ noise_scheduler,
+ current_timestep,
+ denoised_latents,
+ concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size),
+ guidance_scale=1.0,
+ )
+
+ # build_target combines the cached predictions, e.g. for erasing: neutral - guidance_scale * (positive - unconditional)
+ target = setting.build_target(positive_latents, neutral_latents, unconditional_latents)
+ loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=(1, 2, 3))
+ if args.min_snr_gamma is not None and args.min_snr_gamma > 0:
+ # apply_snr_weight expects actual timestep values (0-999), not indices into noise_scheduler.timesteps
+ timesteps = torch.full((loss.shape[0],), int(current_timestep), device=loss.device, dtype=torch.long)
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+ loss = loss.mean() * setting.weight
+
+ accelerator.backward(loss)
+
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+ accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ # zero gradients after the step so accumulated gradients are not cleared mid-cycle
+ optimizer.zero_grad(set_to_none=True)
+
+ if accelerator.sync_gradients:
+ global_step += 1
+ progress_bar.update(1)
+ unwrapped_network = accelerator.unwrap_model(network)
+ unwrapped_network.set_multiplier(0.0)
+
+ logs = {
+ "loss": loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ "guidance_scale": setting.guidance_scale,
+ "network_multiplier": setting.multiplier,
+ }
+ accelerator.log(logs, step=global_step)
+ progress_bar.set_postfix(loss=f"{logs['loss']:.4f}")
+
+ if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps:
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False)
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()