mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 17:35:47 +00:00
Compare commits
11 Commits
feature-ch
...
d98400b06e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d98400b06e | ||
|
|
518545bffb | ||
|
|
d300f19045 | ||
|
|
aec7e16094 | ||
|
|
d53a532a82 | ||
|
|
3adbbb6e33 | ||
|
|
a7b33f3204 | ||
|
|
c0c36a4e2f | ||
|
|
25771a5180 | ||
|
|
e0fcb5152a | ||
|
|
13ccfc39f8 |
@@ -16,6 +16,10 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Jul 21, 2025:
|
||||
- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions.
|
||||
- Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details.
|
||||
|
||||
Jul 10, 2025:
|
||||
- [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
Status: reviewed
|
||||
|
||||
# LoRA Training Guide for Lumina Image 2.0 using `lumina_train_network.py` / `lumina_train_network.py` を用いたLumina Image 2.0モデルのLoRA学習ガイド
|
||||
|
||||
This document explains how to train LoRA (Low-Rank Adaptation) models for Lumina Image 2.0 using `lumina_train_network.py` in the `sd-scripts` repository.
|
||||
@@ -198,6 +196,7 @@ For Lumina Image 2.0, you can specify different dimensions for various component
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のLumina Image 2.0特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
|
||||
|
||||
#### モデル関連
|
||||
@@ -250,6 +249,18 @@ After setting the required arguments, run the command to begin training. The ove
|
||||
|
||||
When training finishes, a LoRA model file (e.g. `my_lumina_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Lumina Image 2.0, such as ComfyUI with appropriate nodes.
|
||||
|
||||
### Inference with scripts in this repository / このリポジトリのスクリプトを使用した推論
|
||||
|
||||
The inference script is also available. The script is `lumina_minimal_inference.py`. See `--help` for options.
|
||||
|
||||
```
|
||||
python lumina_minimal_inference.py --pretrained_model_name_or_path path/to/lumina.safetensors --gemma2_path path/to/gemma.safetensors" --ae_path path/to/flux_ae.safetensors --output_dir path/to/output_dir --offload --seed 1234 --prompt "Positive prompt" --system_prompt "You are an assistant designed to generate high-quality images based on user prompts." --negative_prompt "negative prompt"
|
||||
```
|
||||
|
||||
`--add_system_prompt_to_negative_prompt` option can be used to add the system prompt to the negative prompt.
|
||||
|
||||
`--lora_weights` option can be used to specify the LoRA weights file, and optional multiplier (like `path;1.0`).
|
||||
|
||||
## 6. Others / その他
|
||||
|
||||
`lumina_train_network.py` shares many features with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these, see the [train_network.py guide](train_network.md#5-other-features--その他の機能) or run `python lumina_train_network.py --help`.
|
||||
@@ -279,6 +290,8 @@ Sample prompts can include CFG truncate (`--ctr`) and Renorm CFG (`-rcfg`) param
|
||||
|
||||
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_lumina_lora.safetensors`)が保存されます。このファイルは、Lumina Image 2.0モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
|
||||
|
||||
当リポジトリ内の推論スクリプトを用いて推論することも可能です。スクリプトは`lumina_minimal_inference.py`です。オプションは`--help`で確認できます。記述例は英語版のドキュメントをご確認ください。
|
||||
|
||||
`lumina_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python lumina_train_network.py --help`) を参照してください。
|
||||
|
||||
### 6.1. 推奨設定
|
||||
|
||||
@@ -44,10 +44,21 @@ def load_lumina_model(
|
||||
"""
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
|
||||
# remove "model.diffusion_model." prefix
|
||||
filtered_state_dict = {
|
||||
k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
|
||||
}
|
||||
state_dict = filtered_state_dict
|
||||
|
||||
info = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
logger.info(f"Loaded Lumina: {info}")
|
||||
return model
|
||||
@@ -78,6 +89,13 @@ def load_ae(
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "vae.decoder.conv_in.bias" in sd:
|
||||
# remove "vae." prefix
|
||||
filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")}
|
||||
sd = filtered_sd
|
||||
|
||||
info = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded AE: {info}")
|
||||
return ae
|
||||
@@ -152,6 +170,16 @@ def load_gemma2(
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "text_encoders.gemma2_2b.logit_scale" in sd:
|
||||
# remove "text_encoders.gemma2_2b.transformer.model." prefix
|
||||
filtered_sd = {
|
||||
k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
|
||||
for k, v in sd.items()
|
||||
if k.startswith("text_encoders.gemma2_2b.transformer.model.")
|
||||
}
|
||||
sd = filtered_sd
|
||||
|
||||
info = gemma2.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Gemma2: {info}")
|
||||
return gemma2
|
||||
@@ -173,7 +201,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
|
||||
# Embedding layers
|
||||
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
|
||||
@@ -211,11 +238,11 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
|
||||
|
||||
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
|
||||
# Handle block-specific patterns
|
||||
if '().' in diff_key:
|
||||
if "()." in diff_key:
|
||||
for block_idx in range(num_double_blocks):
|
||||
block_alpha_key = alpha_key.replace('().', f'{block_idx}.')
|
||||
block_diff_key = diff_key.replace('().', f'{block_idx}.')
|
||||
|
||||
block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
|
||||
block_diff_key = diff_key.replace("().", f"{block_idx}.")
|
||||
|
||||
# Search for and convert block-specific keys
|
||||
for input_key, value in list(sd.items()):
|
||||
if input_key == block_diff_key:
|
||||
@@ -228,6 +255,5 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
|
||||
else:
|
||||
print(f"Not found: {diff_key}")
|
||||
|
||||
|
||||
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
|
||||
return new_sd
|
||||
|
||||
@@ -6010,6 +6010,9 @@ def get_noise_noisy_latents_and_timesteps(
|
||||
else:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise
|
||||
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
|
||||
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def generate_image(
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
negative_prompt: Optional[str],
|
||||
args,
|
||||
args: argparse.Namespace,
|
||||
cfg_trunc_ratio: float = 0.25,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
@@ -88,7 +88,9 @@ def generate_image(
|
||||
with torch.no_grad():
|
||||
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(
|
||||
negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt
|
||||
)
|
||||
with torch.no_grad():
|
||||
neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)
|
||||
|
||||
@@ -158,7 +160,7 @@ def generate_image(
|
||||
# 5. Decode latents
|
||||
#
|
||||
logger.info("Decoding image...")
|
||||
latents = latents / ae.scale_factor + ae.shift_factor
|
||||
# latents = latents / ae.scale_factor + ae.shift_factor
|
||||
with torch.no_grad():
|
||||
image = ae.decode(latents.to(ae_dtype))
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
@@ -215,6 +217,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
|
||||
parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
|
||||
parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
|
||||
parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt")
|
||||
parser.add_argument(
|
||||
"--gemma2_max_token_length",
|
||||
type=int,
|
||||
@@ -231,13 +234,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--cfg_trunc_ratio",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="TBD",
|
||||
help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--renorm_cfg",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="TBD",
|
||||
help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
|
||||
@@ -294,7 +294,7 @@ def train(args):
|
||||
# load lumina
|
||||
nextdit = lumina_util.load_lumina_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
loading_dtype,
|
||||
weight_dtype,
|
||||
torch.device("cpu"),
|
||||
disable_mmap=args.disable_mmap_load_safetensors,
|
||||
use_flash_attn=args.use_flash_attn,
|
||||
@@ -494,6 +494,8 @@ def train(args):
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
@@ -739,7 +741,7 @@ def train(args):
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = nextdit(
|
||||
x=img, # image latents (B, C, H, W)
|
||||
x=noisy_model_input, # image latents (B, C, H, W)
|
||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||
cap_mask=gemma2_attn_mask.to(
|
||||
@@ -751,8 +753,8 @@ def train(args):
|
||||
args, model_pred, noisy_model_input, sigmas
|
||||
)
|
||||
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
|
||||
# flow matching loss
|
||||
target = latents - noise
|
||||
|
||||
# calculate loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(
|
||||
|
||||
Reference in New Issue
Block a user