From a96d684ffab11d6f40a8f1dde3c8103ab1d2bd27 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Jul 2025 20:44:43 +0900 Subject: [PATCH 01/15] feat: add Chroma model implementation --- library/chroma_models.py | 706 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 706 insertions(+) create mode 100644 library/chroma_models.py diff --git a/library/chroma_models.py b/library/chroma_models.py new file mode 100644 index 00000000..9f21afad --- /dev/null +++ b/library/chroma_models.py @@ -0,0 +1,706 @@ +# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py +# and modified +# licensed under Apache License 2.0 + +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn +import torch.nn.functional as F +import torch.utils.checkpoint as ckpt + +from .flux_models import ( + attention, + rope, + apply_rope, + EmbedND, + timestep_embedding, + MLPEmbedder, + RMSNorm, + QKNorm, + SelfAttention +) +from . import custom_offloading_utils + + +def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks): + """ + Distributes slices of the tensor into the block_dict as ModulationOut objects. + + Args: + tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim]. + """ + batch_size, vectors, dim = tensor.shape + + block_dict = {} + + # HARD CODED VALUES! lookup table for the generated vectors + # TODO: move this into chroma config! + # Add 38 single mod blocks + for i in range(depth_single_blocks): + key = f"single_blocks.{i}.modulation.lin" + block_dict[key] = None + + # Add 19 image double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.img_mod.lin" + block_dict[key] = None + + # Add 19 text double blocks + for i in range(depth_double_blocks): + key = f"double_blocks.{i}.txt_mod.lin" + block_dict[key] = None + + # Add the final layer + block_dict["final_layer.adaLN_modulation.1"] = None + # 6.2b version + # block_dict["lite_double_blocks.4.img_mod.lin"] = None + # block_dict["lite_double_blocks.4.txt_mod.lin"] = None + + idx = 0 # Index to keep track of the vector slices + + for key in block_dict.keys(): + if "single_blocks" in key: + # Single block: 1 ModulationOut + block_dict[key] = ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + idx += 3 # Advance by 3 vectors + + elif "img_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "txt_mod" in key: + # Double block: List of 2 ModulationOut + double_block = [] + for _ in range(2): # Create 2 ModulationOut objects + double_block.append( + ModulationOut( + shift=tensor[:, idx : idx + 1, :], + scale=tensor[:, idx + 1 : idx + 2, :], + gate=tensor[:, idx + 2 : idx + 3, :], + ) + ) + idx += 3 # Advance by 3 vectors per ModulationOut + block_dict[key] = double_block + + elif "final_layer" in key: + # Final layer: 1 ModulationOut + block_dict[key] = [ + tensor[:, idx : idx + 1, :], + tensor[:, idx + 1 : idx + 2, :], + ] + idx += 2 # Advance by 3 vectors + + return block_dict + + +class Approximator(nn.Module): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4): + super().__init__() + self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for x in range(n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim) for x in range(n_layers)]) + self.out_proj = nn.Linear(hidden_dim, out_dim) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def enable_gradient_checkpointing(self): + for layer in self.layers: + layer.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + for layer in self.layers: + layer.disable_gradient_checkpointing() + + def forward(self, x: Tensor) -> Tensor: + x = self.in_proj(x) + + for layer, norms in zip(self.layers, self.norms): + x = x + layer(norms(x)) + + x = self.out_proj(x) + + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +def _modulation_shift_scale_fn(x, scale, shift): + return (1 + scale) * x + shift + + +def _modulation_gate_fn(x, gate, gate_params): + return x + gate * gate_params + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool = False, + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.gradient_checkpointing = False + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + return _modulation_gate_fn(x, gate, gate_params) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + distill_vec: list[ModulationOut], + mask: Tensor, + ) -> tuple[Tensor, Tensor]: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec + + # prepare image for attention + img_modulated = self.img_norm1(img) + # replaced with compiled fn + # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift) + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + # replaced with compiled fn + # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift) + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe, mask=mask) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + # replaced with compiled fn + # img = img + img_mod1.gate * self.img_attn.proj(img_attn) + # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn)) + img = self.modulation_gate_fn( + img, + img_mod2.gate, + self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)), + ) + + # calculate the txt bloks + # replaced with compiled fn + # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn)) + txt = self.modulation_gate_fn( + txt, + txt_mod2.gate, + self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)), + ) + + return img, txt + + def forward( + self, + img: Tensor, + txt: Tensor, + pe: Tensor, + distill_vec: list[ModulationOut], + mask: Tensor, + ) -> tuple[Tensor, Tensor]: + if self.training and self.gradient_checkpointing: + return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, mask, use_reentrant=False) + else: + return self._forward(img, txt, pe, distill_vec, mask) + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + + self.gradient_checkpointing = False + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def modulation_gate_fn(self, x, gate, gate_params): + return _modulation_gate_fn(x, gate, gate_params) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + + def _forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor: + mod = distill_vec + # replaced with compiled fn + # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift) + qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, mask=mask) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + # replaced with compiled fn + # return x + mod.gate * output + return self.modulation_gate_fn(x, mod.gate, output) + + def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor: + if self.training and self.gradient_checkpointing: + return ckpt.checkpoint(self._forward, x, pe, distill_vec, mask, use_reentrant=False) + else: + return self._forward(x, pe, distill_vec, mask) + + +class LastLayer(nn.Module): + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + ): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def modulation_shift_scale_fn(self, x, scale, shift): + return _modulation_shift_scale_fn(x, scale, shift) + + def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor: + shift, scale = distill_vec + shift = shift.squeeze(1) + scale = scale.squeeze(1) + # replaced with compiled fn + # x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.modulation_shift_scale_fn(self.norm_final(x), scale[:, None, :], shift[:, None, :]) + x = self.linear(x) + return x + + +@dataclass +class ChromaParams: + in_channels: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + approximator_in_dim: int + approximator_depth: int + approximator_hidden_size: int + _use_compiled: bool + + +chroma_params = ChromaParams( + in_channels=64, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + approximator_in_dim=64, + approximator_depth=5, + approximator_hidden_size=5120, + _use_compiled=False, +) + + +def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): + """ + Modifies attention mask to allow attention to a few extra padding tokens. + + Args: + mask: Original attention mask (1 for tokens to attend to, 0 for masked tokens) + max_seq_length: Maximum sequence length of the model + num_extra_padding: Number of padding tokens to unmask + + Returns: + Modified mask + """ + # Get the actual sequence length from the mask + seq_length = mask.sum(dim=-1) + batch_size = mask.shape[0] + + modified_mask = mask.clone() + + for i in range(batch_size): + current_seq_len = int(seq_length[i].item()) + + # Only add extra padding tokens if there's room + if current_seq_len < max_seq_length: + # Calculate how many padding tokens we can unmask + available_padding = max_seq_length - current_seq_len + tokens_to_unmask = min(num_extra_padding, available_padding) + + # Unmask the specified number of padding tokens right after the sequence + modified_mask[i, current_seq_len : current_seq_len + tokens_to_unmask] = 1 + + return modified_mask + + +class Chroma(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + + def __init__(self, params: ChromaParams): + super().__init__() + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + + # TODO: need proper mapping for this approximator output! + # currently the mapping is hardcoded in distribute_modulations function + self.distilled_guidance_layer = Approximator( + params.approximator_in_dim, + self.hidden_size, + params.approximator_hidden_size, + params.approximator_depth, + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer( + self.hidden_size, + 1, + self.out_channels, + ) + + # TODO: move this hardcoded value to config + # single layer has 3 modulation vectors + # double layer has 6 modulation vectors for each expert + # final layer has 2 modulation vectors + self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2 + self.depth_single_blocks = params.depth_single_blocks + self.depth_double_blocks = params.depth + # self.mod_index = torch.tensor(list(range(self.mod_index_length)), device=0) + self.register_buffer( + "mod_index", + torch.tensor(list(range(self.mod_index_length)), device="cpu"), + persistent=False, + ) + self.approximator_in_dim = params.approximator_in_dim + + self.blocks_to_swap = None + self.offloader_double = None + self.offloader_single = None + self.num_double_blocks = len(self.double_blocks) + self.num_single_blocks = len(self.single_blocks) + + @property + def device(self): + # Get the device of the module (assumes all parameters are on the same device) + return next(self.parameters()).device + + def enable_gradient_checkpointing(self): + self.distilled_guidance_layer.enable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + def disable_gradient_checkpointing(self): + self.distilled_guidance_layer.disable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + def enable_block_swap(self, num_blocks: int, device: torch.device): + self.blocks_to_swap = num_blocks + double_blocks_to_swap = num_blocks // 2 + single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + + assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( + f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " + f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." + ) + + self.offloader_double = custom_offloading_utils.ModelOffloader( + self.double_blocks, double_blocks_to_swap, device + ) + self.offloader_single = custom_offloading_utils.ModelOffloader( + self.single_blocks, single_blocks_to_swap, device + ) + print( + f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." + ) + + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu. do not move blocks to device to reduce temporary memory usage + if self.blocks_to_swap: + save_double_blocks = self.double_blocks + save_single_blocks = self.single_blocks + self.double_blocks = None + self.single_blocks = None + + self.to(device) + + if self.blocks_to_swap: + self.double_blocks = save_double_blocks + self.single_blocks = save_single_blocks + + def prepare_block_swap_before_forward(self): + if self.blocks_to_swap is None or self.blocks_to_swap == 0: + return + self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) + self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + txt_mask: Tensor, + timesteps: Tensor, + guidance: Tensor, + attn_padding: int = 1, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + txt = self.txt_in(txt) + + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = ( + torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) + ) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + # compute mask + # assume max seq length from the batched input + + max_len = txt.shape[1] + + # mask + with torch.no_grad(): + txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding) + txt_img_mask = torch.cat( + [ + txt_mask_w_padding, + torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + ], + dim=1, + ) + txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() + txt_img_mask = txt_img_mask[None, None, ...].repeat(txt.shape[0], self.num_heads, 1, 1).int().bool() + # txt_mask_w_padding[txt_mask_w_padding==False] = True + + if not self.blocks_to_swap: + for i, block in enumerate(self.double_blocks): + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask) + else: + for i, block in enumerate(self.double_blocks): + self.offloader_double.wait_for_block(i) + + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] + + img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask) + + self.offloader_double.submit_move_blocks(self.double_blocks, i) + + img = torch.cat((txt, img), 1) + if not self.blocks_to_swap: + for i, block in enumerate(self.single_blocks): + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + else: + for i, block in enumerate(self.single_blocks): + self.offloader_single.wait_for_block(i) + + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + + self.offloader_single.submit_move_blocks(self.single_blocks, i) + img = img[:, txt.shape[1] :, ...] + final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] + img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels) + return img From 24d2ea86c70482ec062412e4214ae221a22cd0a0 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Sun, 20 Jul 2025 12:56:42 +0900 Subject: [PATCH 02/15] feat: support Chroma model in loading and inference processes --- flux_minimal_inference.py | 49 +++++++++++------ flux_train.py | 4 +- flux_train_control_net.py | 4 +- flux_train_network.py | 4 +- library/chroma_models.py | 85 +++++------------------------ library/flux_utils.py | 110 +++++++++++++++++++++++++------------- 6 files changed, 123 insertions(+), 133 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1..a7bff74d 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -108,12 +108,18 @@ def denoise( else: b_img = img + # For Chroma model, y might be None, so create dummy tensor + if b_vec is None: + y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor + else: + y_input = b_vec + pred = model( img=b_img, img_ids=b_img_ids, txt=b_txt, txt_ids=b_txt_ids, - y=b_vec, + y=y_input, timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=b_t5_attn_mask, @@ -134,7 +140,7 @@ def do_sample( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - l_pooled: torch.Tensor, + l_pooled: Optional[torch.Tensor], t5_out: torch.Tensor, txt_ids: torch.Tensor, num_steps: int, @@ -192,7 +198,7 @@ def do_sample( def generate_image( model, - clip_l: CLIPTextModel, + clip_l: Optional[CLIPTextModel], t5xxl, ae, prompt: str, @@ -231,7 +237,7 @@ def generate_image( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) # prepare fp8 models - if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): + if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared): logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}") clip_l.to(clip_l_dtype) # fp8 clip_l.text_model.embeddings.to(dtype=torch.bfloat16) @@ -267,18 +273,22 @@ def generate_image( # prepare embeddings logger.info("Encoding prompts...") - clip_l = clip_l.to(device) + if clip_l is not None: + clip_l = clip_l.to(device) t5xxl = t5xxl.to(device) def encode(prpt: str): tokens_and_masks = tokenize_strategy.tokenize(prpt) with torch.no_grad(): - if is_fp8(clip_l_dtype): - with accelerator.autocast(): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + if clip_l is not None: + if is_fp8(clip_l_dtype): + with accelerator.autocast(): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + else: + with torch.autocast(device_type=device.type, dtype=clip_l_dtype): + l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) else: - with torch.autocast(device_type=device.type, dtype=clip_l_dtype): - l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks) + l_pooled = None if is_fp8(t5xxl_dtype): with accelerator.autocast(): @@ -288,7 +298,7 @@ def generate_image( else: with torch.autocast(device_type=device.type, dtype=t5xxl_dtype): _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens( - tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask + tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask ) return l_pooled, t5_out, txt_ids, t5_attn_mask @@ -305,7 +315,8 @@ def generate_image( raise ValueError("NaN in t5_out") if args.offload: - clip_l = clip_l.cpu() + if clip_l is not None: + clip_l = clip_l.cpu() t5xxl = t5xxl.cpu() # del clip_l, t5xxl device_utils.clean_memory() @@ -385,6 +396,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use") parser.add_argument("--clip_l", type=str, required=False) parser.add_argument("--t5xxl", type=str, required=False) parser.add_argument("--ae", type=str, required=False) @@ -438,10 +450,13 @@ if __name__ == "__main__": else: accelerator = None - # load clip_l - logger.info(f"Loading clip_l from {args.clip_l}...") - clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) - clip_l.eval() + # load clip_l (skip for chroma model) + if args.model_type == "flux": + logger.info(f"Loading clip_l from {args.clip_l}...") + clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device) + clip_l.eval() + else: + clip_l = None logger.info(f"Loading t5xxl from {args.t5xxl}...") t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device) @@ -453,7 +468,7 @@ if __name__ == "__main__": # t5xxl = accelerator.prepare(t5xxl) # DiT - is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device) + model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train.py b/flux_train.py index 6f98adea..1d2cc68b 100644 --- a/flux_train.py +++ b/flux_train.py @@ -270,8 +270,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - _, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + model_type, _, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) if args.gradient_checkpointing: diff --git a/flux_train_control_net.py b/flux_train_control_net.py index cecd0001..3c038c32 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -258,8 +258,8 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - is_schnell, flux = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors + model_type, is_schnell, flux = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) flux.requires_grad_(False) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..b2bf8e7c 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -95,8 +95,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors + self.model_type, self.is_schnell, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux" ) if args.fp8_base: # check dtype of model diff --git a/library/chroma_models.py b/library/chroma_models.py index 9f21afad..e1da751b 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -11,17 +11,7 @@ from torch import Tensor, nn import torch.nn.functional as F import torch.utils.checkpoint as ckpt -from .flux_models import ( - attention, - rope, - apply_rope, - EmbedND, - timestep_embedding, - MLPEmbedder, - RMSNorm, - QKNorm, - SelfAttention -) +from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux from . import custom_offloading_utils @@ -468,13 +458,13 @@ def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8): return modified_mask -class Chroma(nn.Module): +class Chroma(Flux): """ Transformer model for flow matching on sequences. """ def __init__(self, params: ChromaParams): - super().__init__() + nn.Module.__init__(self) self.params = params self.in_channels = params.in_channels self.out_channels = self.in_channels @@ -548,60 +538,9 @@ class Chroma(nn.Module): self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) - @property - def device(self): - # Get the device of the module (assumes all parameters are on the same device) - return next(self.parameters()).device - - def enable_gradient_checkpointing(self): - self.distilled_guidance_layer.enable_gradient_checkpointing() - for block in self.double_blocks + self.single_blocks: - block.enable_gradient_checkpointing() - - def disable_gradient_checkpointing(self): - self.distilled_guidance_layer.disable_gradient_checkpointing() - for block in self.double_blocks + self.single_blocks: - block.disable_gradient_checkpointing() - - def enable_block_swap(self, num_blocks: int, device: torch.device): - self.blocks_to_swap = num_blocks - double_blocks_to_swap = num_blocks // 2 - single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 - - assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, ( - f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. " - f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks." - ) - - self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, double_blocks_to_swap, device - ) - self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, single_blocks_to_swap, device - ) - print( - f"Chroma: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." - ) - - def move_to_device_except_swap_blocks(self, device: torch.device): - # assume model is on cpu. do not move blocks to device to reduce temporary memory usage - if self.blocks_to_swap: - save_double_blocks = self.double_blocks - save_single_blocks = self.single_blocks - self.double_blocks = None - self.single_blocks = None - - self.to(device) - - if self.blocks_to_swap: - self.double_blocks = save_double_blocks - self.single_blocks = save_single_blocks - - def prepare_block_swap_before_forward(self): - if self.blocks_to_swap is None or self.blocks_to_swap == 0: - return - self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) - self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + # Initialize properties required by Flux parent class + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False def forward( self, @@ -609,10 +548,12 @@ class Chroma(nn.Module): img_ids: Tensor, txt: Tensor, txt_ids: Tensor, - txt_mask: Tensor, timesteps: Tensor, - guidance: Tensor, - attn_padding: int = 1, + y: Tensor, + block_controlnet_hidden_states=None, + block_controlnet_single_hidden_states=None, + guidance: Tensor | None = None, + txt_attention_mask: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -654,11 +595,11 @@ class Chroma(nn.Module): # mask with torch.no_grad(): - txt_mask_w_padding = modify_mask_to_attend_padding(txt_mask, max_len, attn_padding) + txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1) txt_img_mask = torch.cat( [ txt_mask_w_padding, - torch.ones([img.shape[0], img.shape[1]], device=txt_mask.device), + torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device), ], dim=1, ) diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..a5cfcdff 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -92,50 +92,84 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int def load_flow_model( - ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False -) -> Tuple[bool, flux_models.Flux]: - is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + model_type: str = "flux", +) -> Tuple[str, bool, flux_models.Flux]: + if model_type == "flux": + is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - # build model - logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") - with torch.device("meta"): - params = flux_models.configs[name].params + # build model + logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") + with torch.device("meta"): + params = flux_models.configs[name].params - # set the number of blocks - if params.depth != num_double_blocks: - logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") - params = replace(params, depth=num_double_blocks) - if params.depth_single_blocks != num_single_blocks: - logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") - params = replace(params, depth_single_blocks=num_single_blocks) + # set the number of blocks + if params.depth != num_double_blocks: + logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}") + params = replace(params, depth=num_double_blocks) + if params.depth_single_blocks != num_single_blocks: + logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}") + params = replace(params, depth_single_blocks=num_single_blocks) - model = flux_models.Flux(params) - if dtype is not None: - model = model.to(dtype) + model = flux_models.Flux(params) + if dtype is not None: + model = model.to(dtype) - # load_sft doesn't support torch.device - logger.info(f"Loading state dict from {ckpt_path}") - sd = {} - for ckpt_path in ckpt_paths: - sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = {} + for ckpt_path in ckpt_paths: + sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)) - # convert Diffusers to BFL - if is_diffusers: - logger.info("Converting Diffusers to BFL") - sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) - logger.info("Converted Diffusers to BFL") + # convert Diffusers to BFL + if is_diffusers: + logger.info("Converting Diffusers to BFL") + sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) + logger.info("Converted Diffusers to BFL") - # if the key has annoying prefix, remove it - for key in list(sd.keys()): - new_key = key.replace("model.diffusion_model.", "") - if new_key == key: - break # the model doesn't have annoying prefix - sd[new_key] = sd.pop(key) + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) - info = model.load_state_dict(sd, strict=False, assign=True) - logger.info(f"Loaded Flux: {info}") - return is_schnell, model + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Flux: {info}") + return model_type, is_schnell, model + + elif model_type == "chroma": + from . import chroma_models + + # build model + logger.info("Building Chroma model from BFL checkpoint") + with torch.device("meta"): + model = chroma_models.Chroma(chroma_models.chroma_params) + if dtype is not None: + model = model.to(dtype) + + # load_sft doesn't support torch.device + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + + info = model.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded Chroma: {info}") + is_schnell = False # Chroma is not schnell + return model_type, is_schnell, model + + else: + raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.") def load_ae( @@ -166,7 +200,7 @@ def load_controlnet( sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) info = controlnet.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded ControlNet: {info}") - return controlnet + return controlnet def load_clip_l( From 404ddb060d04285d72ffff9342542eec71d9c352 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 14:08:54 +0900 Subject: [PATCH 03/15] fix: inference for Chroma model --- flux_minimal_inference.py | 30 +++++++++++++++--------------- library/chroma_models.py | 9 +++++++-- library/flux_utils.py | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index a7bff74d..550904d2 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -78,16 +78,19 @@ def denoise( neg_t5_attn_mask: Optional[torch.Tensor] = None, cfg_scale: Optional[float] = None, ): - # this is ignored for schnell - logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") - guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - # prepare classifier free guidance - if neg_txt is not None and neg_vec is not None: + logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}") + do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0) + + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype) + + if do_cfg: + print("Using classifier free guidance") b_img_ids = torch.cat([img_ids, img_ids], dim=0) b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0) b_txt = torch.cat([neg_txt, txt], dim=0) - b_vec = torch.cat([neg_vec, vec], dim=0) + b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None if t5_attn_mask is not None and neg_t5_attn_mask is not None: b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) else: @@ -103,17 +106,13 @@ def denoise( t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device) # classifier free guidance - if neg_txt is not None and neg_vec is not None: + if do_cfg: b_img = torch.cat([img, img], dim=0) else: b_img = img - # For Chroma model, y might be None, so create dummy tensor - if b_vec is None: - y_input = torch.zeros_like(b_txt[:, :1, :]) # dummy tensor - else: - y_input = b_vec - + y_input = b_vec + pred = model( img=b_img, img_ids=b_img_ids, @@ -126,7 +125,7 @@ def denoise( ) # classifier free guidance - if neg_txt is not None and neg_vec is not None: + if do_cfg: pred_uncond, pred = torch.chunk(pred, 2, dim=0) pred = pred_uncond + cfg_scale * (pred - pred_uncond) @@ -309,7 +308,7 @@ def generate_image( neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None # NaN check - if torch.isnan(l_pooled).any(): + if l_pooled is not None and torch.isnan(l_pooled).any(): raise ValueError("NaN in l_pooled") if torch.isnan(t5_out).any(): raise ValueError("NaN in t5_out") @@ -329,6 +328,7 @@ def generate_image( img_ids = img_ids.to(device) t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None + neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None x = do_sample( accelerator, diff --git a/library/chroma_models.py b/library/chroma_models.py index e1da751b..f725db87 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module): k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe, mask=mask) + attn = attention(q, k, v, pe=pe, attn_mask=mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks @@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module): q, k = self.norm(q, k, v) # compute attention - attn = attention(q, k, v, pe=pe, mask=mask) + attn = attention(q, k, v, pe=pe, attn_mask=mask) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) # replaced with compiled fn @@ -555,6 +555,11 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, ) -> Tensor: + # print( + # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" + # ) + # print(f"timesteps: {timesteps}, guidance: {guidance}") + if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") diff --git a/library/flux_utils.py b/library/flux_utils.py index a5cfcdff..dda7c789 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -146,7 +146,7 @@ def load_flow_model( from . import chroma_models # build model - logger.info("Building Chroma model from BFL checkpoint") + logger.info("Building Chroma model") with torch.device("meta"): model = chroma_models.Chroma(chroma_models.chroma_params) if dtype is not None: From 8fd0b12d1f8bcae52cb11f0ccd193d8382b06166 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 16:00:58 +0900 Subject: [PATCH 04/15] feat: update DoubleStreamBlock and SingleStreamBlock to handle text sequence lengths instead of mask --- library/chroma_models.py | 232 ++++++++++++++++++++++++++------------- 1 file changed, 154 insertions(+), 78 deletions(-) diff --git a/library/chroma_models.py b/library/chroma_models.py index f725db87..06822a37 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -211,9 +211,9 @@ class DoubleStreamBlock(nn.Module): self, img: Tensor, txt: Tensor, - pe: Tensor, + pe: list[Tensor], distill_vec: list[ModulationOut], - mask: Tensor, + txt_seq_len: Tensor, ) -> tuple[Tensor, Tensor]: (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec @@ -235,13 +235,58 @@ class DoubleStreamBlock(nn.Module): txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - # run actual attention - q = torch.cat((txt_q, img_q), dim=2) - k = torch.cat((txt_k, img_k), dim=2) - v = torch.cat((txt_v, img_v), dim=2) + # run actual attention: we split the batch into each element + max_txt_len = txt_q.shape[-2] # max 512 + txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0)) # list of [B, H, L, D] tensors + txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0)) + txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0)) + img_q = list(torch.chunk(img_q, img_q.shape[0], dim=0)) + img_k = list(torch.chunk(img_k, img_k.shape[0], dim=0)) + img_v = list(torch.chunk(img_v, img_v.shape[0], dim=0)) + txt_attn = [] + img_attn = [] + for i in range(txt.shape[0]): + print(i) + print(f"len(txt_q) = {len(txt_q)}, len(img_q) = {len(img_q)}, txt_seq_len.shape = {txt_seq_len.shape}") + print(f"txt_seq_len[i] = {txt_seq_len[i]}, txt_q.shape = {txt_q[i].shape}, img_q.shape = {img_q[i].shape}") + txt_q_i = txt_q[i][:, :, : txt_seq_len[i]] + txt_q[i] = None + img_q_i = img_q[i] + img_q[i] = None + q = torch.cat((txt_q_i, img_q_i), dim=2) + del txt_q_i, img_q_i - attn = attention(q, k, v, pe=pe, attn_mask=mask) - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + txt_k_i = txt_k[i][:, :, : txt_seq_len[i]] + txt_k[i] = None + img_k_i = img_k[i] + img_k[i] = None + k = torch.cat((txt_k_i, img_k_i), dim=2) + del txt_k_i, img_k_i + + txt_v_i = txt_v[i][:, :, : txt_seq_len[i]] + txt_v[i] = None + img_v_i = img_v[i] + img_v[i] = None + v = torch.cat((txt_v_i, img_v_i), dim=2) + del txt_v_i, img_v_i + + attn = attention(q, k, v, pe=pe[i], attn_mask=None) # (1, L, D) + print(f"attn.shape = {attn.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}") + txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device) + txt_attn_i[:, : txt_seq_len[i], :] = attn[:, : txt_seq_len[i], :] + img_attn_i = attn[:, txt_seq_len[i] :, :] + txt_attn.append(txt_attn_i) + img_attn.append(img_attn_i) + + txt_attn = torch.cat(txt_attn, dim=0) + img_attn = torch.cat(img_attn, dim=0) + + # q = torch.cat((txt_q, img_q), dim=2) + # k = torch.cat((txt_k, img_k), dim=2) + # v = torch.cat((txt_v, img_v), dim=2) + + # attn = attention(q, k, v, pe=pe, attn_mask=mask) + # txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks # replaced with compiled fn @@ -273,12 +318,12 @@ class DoubleStreamBlock(nn.Module): txt: Tensor, pe: Tensor, distill_vec: list[ModulationOut], - mask: Tensor, + txt_seq_len: Tensor, ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: - return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, mask, use_reentrant=False) + return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, txt_seq_len, use_reentrant=False) else: - return self._forward(img, txt, pe, distill_vec, mask) + return self._forward(img, txt, pe, distill_vec, txt_seq_len) class SingleStreamBlock(nn.Module): @@ -332,7 +377,9 @@ class SingleStreamBlock(nn.Module): def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - def _forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor: + def _forward( + self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int + ) -> Tensor: mod = distill_vec # replaced with compiled fn # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift @@ -342,19 +389,44 @@ class SingleStreamBlock(nn.Module): q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) - # compute attention - attn = attention(q, k, v, pe=pe, attn_mask=mask) + # # compute attention + # attn = attention(q, k, v, pe=pe, attn_mask=mask) + + # compute attention: we split the batch into each element + q = list(torch.chunk(q, q.shape[0], dim=0)) + k = list(torch.chunk(k, k.shape[0], dim=0)) + v = list(torch.chunk(v, v.shape[0], dim=0)) + attn = [] + for i in range(x.size(0)): + q_i = torch.cat((q[i][:, :, : txt_seq_len[i]], q[i][:, :, max_txt_len:]), dim=2) + q[i] = None + k_i = torch.cat((k[i][:, :, : txt_seq_len[i]], k[i][:, :, max_txt_len:]), dim=2) + k[i] = None + v_i = torch.cat((v[i][:, :, : txt_seq_len[i]], v[i][:, :, max_txt_len:]), dim=2) + v[i] = None + attn_trimmed = attention(q_i, k_i, v_i, pe=pe[i], attn_mask=None) + print( + f"attn_trimmed.shape = {attn_trimmed.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}, x.shape = {x.shape}" + ) + + attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device) + attn_i[:, : txt_seq_len[i], :] = attn_trimmed[:, : txt_seq_len[i], :] + attn_i[:, max_txt_len:, :] = attn_trimmed[:, txt_seq_len[i] :, :] + attn.append(attn_i) + + attn = torch.cat(attn, dim=0) + # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) # replaced with compiled fn # return x + mod.gate * output return self.modulation_gate_fn(x, mod.gate, output) - def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int) -> Tensor: if self.training and self.gradient_checkpointing: - return ckpt.checkpoint(self._forward, x, pe, distill_vec, mask, use_reentrant=False) + return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, max_txt_len, use_reentrant=False) else: - return self._forward(x, pe, distill_vec, mask) + return self._forward(x, pe, distill_vec, txt_seq_len, max_txt_len) class LastLayer(nn.Module): @@ -542,6 +614,29 @@ class Chroma(Flux): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False + def get_mod_vectors( + self, + timesteps: Tensor, + guidance: Tensor | None = None, + batch_size: int | None = None, + requires_grad: bool = False, + ) -> Tensor: + distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) + # TODO: need to add toggle to omit this from schnell but that's not a priority + distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) + # get all modulation index + modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2) + # we need to broadcast the modulation index here so each batch has all of the index + modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1) + # and we need to broadcast timestep and guidance along too + timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) + # then and only then we could concatenate it together + input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) + if requires_grad: + input_vec = input_vec.requires_grad_(True) + mod_vectors = self.distilled_guidance_layer(input_vec) + return mod_vectors + def forward( self, img: Tensor, @@ -554,6 +649,8 @@ class Chroma(Flux): block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, + attn_padding: int = 1, + mod_vectors: Tensor | None = None, ) -> Tensor: # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" @@ -567,85 +664,64 @@ class Chroma(Flux): img = self.img_in(img) txt = self.txt_in(txt) - # TODO: - # need to fix grad accumulation issue here for now it's in no grad mode - # besides, i don't want to wash out the PFP that's trained on this model weights anyway - # the fan out operation here is deleting the backward graph - # alternatively doing forward pass for every block manually is doable but slow - # custom backward probably be better - with torch.no_grad(): - distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) - # TODO: need to add toggle to omit this from schnell but that's not a priority - distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) - # get all modulation index - modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2) - # we need to broadcast the modulation index here so each batch has all of the index - modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1) - # and we need to broadcast timestep and guidance along too - timestep_guidance = ( - torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) - ) - # then and only then we could concatenate it together - input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) - mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True)) + if mod_vectors is None: + # TODO: + # need to fix grad accumulation issue here for now it's in no grad mode + # besides, i don't want to wash out the PFP that's trained on this model weights anyway + # the fan out operation here is deleting the backward graph + # alternatively doing forward pass for every block manually is doable but slow + # custom backward probably be better + with torch.no_grad(): + # kohya-ss: I'm not sure why requires_grad is set to True here + mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True) + mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 - # compute mask - # assume max seq length from the batched input + # calculate text length for each batch instead of masking + txt_emb_len = txt.shape[1] + txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, ) + txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len) + max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch - max_len = txt.shape[1] + # trim txt embedding to the text length + txt = txt[:, :max_txt_len, :] - # mask - with torch.no_grad(): - txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1) - txt_img_mask = torch.cat( - [ - txt_mask_w_padding, - torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device), - ], - dim=1, - ) - txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float() - txt_img_mask = txt_img_mask[None, None, ...].repeat(txt.shape[0], self.num_heads, 1, 1).int().bool() - # txt_mask_w_padding[txt_mask_w_padding==False] = True + # split positional encoding into each element of the batch, and trim masked tokens + print(f"pe shape = {pe.shape} dtype = {pe.dtype}, txt_seq_len = {txt_seq_len}") + pe = list(torch.chunk(pe, pe.shape[0], dim=0)) + for i in range(len(pe)): + # trim positional encoding to the text length + pe[i] = torch.cat([pe[i][:, :, : txt_seq_len[i]], pe[i][:, :, txt_emb_len:]], dim=2) - if not self.blocks_to_swap: - for i, block in enumerate(self.double_blocks): - # the guidance replaced by FFN output - img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] - txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] - double_mod = [img_mod, txt_mod] - - img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask) - else: - for i, block in enumerate(self.double_blocks): + for i, block in enumerate(self.double_blocks): + if self.blocks_to_swap: self.offloader_double.wait_for_block(i) - # the guidance replaced by FFN output - img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] - txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] - double_mod = [img_mod, txt_mod] + # the guidance replaced by FFN output + img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] + txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + double_mod = [img_mod, txt_mod] - img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask) + img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len) + if self.blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, i) img = torch.cat((txt, img), 1) - if not self.blocks_to_swap: - for i, block in enumerate(self.single_blocks): - single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] - img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) - else: - for i, block in enumerate(self.single_blocks): + + for i, block in enumerate(self.single_blocks): + if self.blocks_to_swap: self.offloader_single.wait_for_block(i) - single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] - img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask) + single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len, max_txt_len=max_txt_len) + if self.blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, i) + img = img[:, txt.shape[1] :, ...] final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels) From c4958b5dca0102b3f18fa2d2a383f177d508f872 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 16:30:43 +0900 Subject: [PATCH 05/15] feat: change img/txt order for attention and single blocks --- library/chroma_models.py | 75 +++++++++++++++------------------------- 1 file changed, 28 insertions(+), 47 deletions(-) diff --git a/library/chroma_models.py b/library/chroma_models.py index 06822a37..1b62f20f 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -236,7 +236,8 @@ class DoubleStreamBlock(nn.Module): txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention: we split the batch into each element - max_txt_len = txt_q.shape[-2] # max 512 + max_txt_len = torch.max(txt_seq_len).item() + img_len = img_q.shape[-2] # max 64 txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0)) # list of [B, H, L, D] tensors txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0)) txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0)) @@ -246,35 +247,25 @@ class DoubleStreamBlock(nn.Module): txt_attn = [] img_attn = [] for i in range(txt.shape[0]): - print(i) - print(f"len(txt_q) = {len(txt_q)}, len(img_q) = {len(img_q)}, txt_seq_len.shape = {txt_seq_len.shape}") - print(f"txt_seq_len[i] = {txt_seq_len[i]}, txt_q.shape = {txt_q[i].shape}, img_q.shape = {img_q[i].shape}") - txt_q_i = txt_q[i][:, :, : txt_seq_len[i]] + txt_q[i] = txt_q[i][:, :, : txt_seq_len[i]] + q = torch.cat((img_q[i], txt_q[i]), dim=2) txt_q[i] = None - img_q_i = img_q[i] img_q[i] = None - q = torch.cat((txt_q_i, img_q_i), dim=2) - del txt_q_i, img_q_i - txt_k_i = txt_k[i][:, :, : txt_seq_len[i]] + txt_k[i] = txt_k[i][:, :, : txt_seq_len[i]] + k = torch.cat((img_k[i], txt_k[i]), dim=2) txt_k[i] = None - img_k_i = img_k[i] img_k[i] = None - k = torch.cat((txt_k_i, img_k_i), dim=2) - del txt_k_i, img_k_i - txt_v_i = txt_v[i][:, :, : txt_seq_len[i]] + txt_v[i] = txt_v[i][:, :, : txt_seq_len[i]] + v = torch.cat((img_v[i], txt_v[i]), dim=2) txt_v[i] = None - img_v_i = img_v[i] img_v[i] = None - v = torch.cat((txt_v_i, img_v_i), dim=2) - del txt_v_i, img_v_i - attn = attention(q, k, v, pe=pe[i], attn_mask=None) # (1, L, D) - print(f"attn.shape = {attn.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}") + attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D) + img_attn_i = attn[:, :img_len, :] txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device) - txt_attn_i[:, : txt_seq_len[i], :] = attn[:, : txt_seq_len[i], :] - img_attn_i = attn[:, txt_seq_len[i] :, :] + txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :] txt_attn.append(txt_attn_i) img_attn.append(img_attn_i) @@ -377,9 +368,7 @@ class SingleStreamBlock(nn.Module): def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - def _forward( - self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int - ) -> Tensor: + def _forward(self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: mod = distill_vec # replaced with compiled fn # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift @@ -393,25 +382,23 @@ class SingleStreamBlock(nn.Module): # attn = attention(q, k, v, pe=pe, attn_mask=mask) # compute attention: we split the batch into each element + max_txt_len = torch.max(txt_seq_len).item() + img_len = q.shape[-2] - max_txt_len q = list(torch.chunk(q, q.shape[0], dim=0)) k = list(torch.chunk(k, k.shape[0], dim=0)) v = list(torch.chunk(v, v.shape[0], dim=0)) attn = [] for i in range(x.size(0)): - q_i = torch.cat((q[i][:, :, : txt_seq_len[i]], q[i][:, :, max_txt_len:]), dim=2) + q[i] = q[i][:, :, : img_len + txt_seq_len[i]] + k[i] = k[i][:, :, : img_len + txt_seq_len[i]] + v[i] = v[i][:, :, : img_len + txt_seq_len[i]] + attn_trimmed = attention(q[i], k[i], v[i], pe=pe[i : i + 1, :, : img_len + txt_seq_len[i]], attn_mask=None) q[i] = None - k_i = torch.cat((k[i][:, :, : txt_seq_len[i]], k[i][:, :, max_txt_len:]), dim=2) k[i] = None - v_i = torch.cat((v[i][:, :, : txt_seq_len[i]], v[i][:, :, max_txt_len:]), dim=2) v[i] = None - attn_trimmed = attention(q_i, k_i, v_i, pe=pe[i], attn_mask=None) - print( - f"attn_trimmed.shape = {attn_trimmed.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}, x.shape = {x.shape}" - ) attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device) - attn_i[:, : txt_seq_len[i], :] = attn_trimmed[:, : txt_seq_len[i], :] - attn_i[:, max_txt_len:, :] = attn_trimmed[:, txt_seq_len[i] :, :] + attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed attn.append(attn_i) attn = torch.cat(attn, dim=0) @@ -422,11 +409,11 @@ class SingleStreamBlock(nn.Module): # return x + mod.gate * output return self.modulation_gate_fn(x, mod.gate, output) - def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, max_txt_len, use_reentrant=False) + return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, use_reentrant=False) else: - return self._forward(x, pe, distill_vec, txt_seq_len, max_txt_len) + return self._forward(x, pe, distill_vec, txt_seq_len) class LastLayer(nn.Module): @@ -677,9 +664,6 @@ class Chroma(Flux): mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 - # calculate text length for each batch instead of masking txt_emb_len = txt.shape[1] txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, ) @@ -689,12 +673,9 @@ class Chroma(Flux): # trim txt embedding to the text length txt = txt[:, :max_txt_len, :] - # split positional encoding into each element of the batch, and trim masked tokens - print(f"pe shape = {pe.shape} dtype = {pe.dtype}, txt_seq_len = {txt_seq_len}") - pe = list(torch.chunk(pe, pe.shape[0], dim=0)) - for i in range(len(pe)): - # trim positional encoding to the text length - pe[i] = torch.cat([pe[i][:, :, : txt_seq_len[i]], pe[i][:, :, txt_emb_len:]], dim=2) + # create positional encoding for the text and image + ids = torch.cat((img_ids, txt_ids[:, :max_txt_len]), dim=1) # reverse order of ids for faster attention + pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 for i, block in enumerate(self.double_blocks): if self.blocks_to_swap: @@ -710,19 +691,19 @@ class Chroma(Flux): if self.blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, i) - img = torch.cat((txt, img), 1) + img = torch.cat((img, txt), 1) for i, block in enumerate(self.single_blocks): if self.blocks_to_swap: self.offloader_single.wait_for_block(i) single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] - img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len, max_txt_len=max_txt_len) + img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len) if self.blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, i) - img = img[:, txt.shape[1] :, ...] + img = img[:, :-max_txt_len, ...] final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img From b4e862626aaba996ffe8b7f942ce5ce21d762919 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 19:00:09 +0900 Subject: [PATCH 06/15] feat: add LoRA training support for Chroma --- flux_minimal_inference.py | 2 +- flux_train.py | 2 +- flux_train_control_net.py | 7 +- flux_train_network.py | 102 +++++++++------------ library/chroma_models.py | 50 ++++++---- library/flux_models.py | 177 +----------------------------------- library/flux_train_utils.py | 19 ++-- library/flux_utils.py | 43 ++++++++- library/sai_model_spec.py | 14 ++- library/train_util.py | 2 +- 10 files changed, 158 insertions(+), 260 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 550904d2..86e8e1b1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -468,7 +468,7 @@ if __name__ == "__main__": # t5xxl = accelerator.prepare(t5xxl) # DiT - model_type, is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) + is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type) model.eval() logger.info(f"Casting model to {flux_dtype}") model.to(flux_dtype) # make sure model is dtype diff --git a/flux_train.py b/flux_train.py index 1d2cc68b..84db34cf 100644 --- a/flux_train.py +++ b/flux_train.py @@ -270,7 +270,7 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - model_type, _, flux = flux_utils.load_flow_model( + _, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 3c038c32..93c20dab 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -68,6 +68,11 @@ def train(args): if not args.skip_cache_check: args.skip_cache_check = args.skip_latents_validity_check + if args.model_type != "flux": + raise ValueError( + f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。" + ) + # assert ( # not args.weighted_captions # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" @@ -258,7 +263,7 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - model_type, is_schnell, flux = flux_utils.load_flow_model( + is_schnell, flux = flux_utils.load_flow_model( args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" ) flux.requires_grad_(False) diff --git a/flux_train_network.py b/flux_train_network.py index b2bf8e7c..1b61ac72 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.sample_prompts_te_outputs = None self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False + self.model_type: Optional[str] = None def assert_extra_args( self, @@ -45,6 +46,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) + self.model_type = args.model_type # "flux" or "chroma" + if self.model_type != "chroma": + self.use_clip_l = True + else: + self.use_clip_l = False # Chroma does not use CLIP-L + if args.fp8_base_unet: args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 @@ -60,7 +67,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" # prepare CLIP-L/T5XXL training flags - self.train_clip_l = not args.network_train_unet_only + self.train_clip_l = not args.network_train_unet_only and self.use_clip_l self.train_t5xxl = False # default is False even if args.network_train_unet_only is False if args.max_token_length is not None: @@ -95,8 +102,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_base else weight_dtype # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future - self.model_type, self.is_schnell, model = flux_utils.load_flow_model( - args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, model_type="flux" + _, model = flux_utils.load_flow_model( + args.pretrained_model_name_or_path, + loading_dtype, + "cpu", + disable_mmap=args.disable_mmap_load_safetensors, + model_type=self.model_type, ) if args.fp8_base: # check dtype of model @@ -120,7 +131,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") model.enable_block_swap(args.blocks_to_swap, accelerator.device) - clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + if self.use_clip_l: + clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + else: + clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L clip_l.eval() # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8) @@ -141,13 +155,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model + model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA + return model_version, [clip_l, t5xxl], ae, model def get_tokenize_strategy(self, args): - _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + # This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here. + # Instead, we analyze the checkpoint state to determine if it is schnell. + if args.model_type != "chroma": + _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path) + else: + is_schnell = False + self.is_schnell = is_schnell if args.t5xxl_max_token_length is None: - if is_schnell: + if self.is_schnell: t5xxl_max_token_length = 256 else: t5xxl_max_token_length = 512 @@ -268,23 +289,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device) - # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): - # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype - - # # get size embeddings - # orig_size = batch["original_sizes_hw"] - # crop_size = batch["crop_top_lefts"] - # target_size = batch["target_sizes_hw"] - # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) - - # # concat embeddings - # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds - # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) - # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) - - # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - # return noise_pred - def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux): text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) @@ -292,36 +296,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): flux_train_utils.sample_images( accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs ) - # return - - """ - class FluxUpperLowerWrapper(torch.nn.Module): - def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device): - super().__init__() - self.flux_upper = flux_upper - self.flux_lower = flux_lower - self.target_device = device - - def prepare_block_swap_before_forward(self): - pass - - def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): - self.flux_lower.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_upper.to(self.target_device) - img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask) - self.flux_upper.to("cpu") - clean_memory_on_device(self.target_device) - self.flux_lower.to(self.target_device) - return self.flux_lower(img, txt, vec, pe, txt_attention_mask) - - wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device) - clean_memory_on_device(accelerator.device) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs - ) - clean_memory_on_device(accelerator.device) - """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) @@ -366,7 +340,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # ensure guidance_scale in args is float guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - # ensure the hidden state will require grad + # get modulation vectors for Chroma + input_vec = None + if self.model_type == "chroma": + input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz) + if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: @@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) + if input_vec is not None: + input_vec.requires_grad_(True) # Predict the noise residual l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps / 1000, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, + input_vec=input_vec, ) return model_pred @@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps, guidance_vec=guidance_vec, t5_attn_mask=t5_attn_mask, + input_vec=input_vec, ) # unpack latents @@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps[diff_output_pr_indices], guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, + input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step @@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return loss def get_sai_model_spec(self, args): - return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev") + if self.model_type != "chroma": + model_description = "schnell" if self.is_schnell else "dev" + else: + model_description = "chroma" + return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description) def update_metadata(self, metadata, args): + metadata["ss_model_type"] = args.model_type metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_logit_mean"] = args.logit_mean diff --git a/library/chroma_models.py b/library/chroma_models.py index 1b62f20f..e5d3b547 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -601,13 +601,30 @@ class Chroma(Flux): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def get_mod_vectors( - self, - timesteps: Tensor, - guidance: Tensor | None = None, - batch_size: int | None = None, - requires_grad: bool = False, - ) -> Tensor: + def get_model_type(self) -> str: + return "chroma" + + def enable_gradient_checkpointing(self, cpu_offload: bool = False): + self.gradient_checkpointing = True + self.cpu_offload_checkpointing = cpu_offload + + self.distilled_guidance_layer.enable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.enable_gradient_checkpointing() + + print(f"Chroma: Gradient checkpointing enabled.") + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.cpu_offload_checkpointing = False + + self.distilled_guidance_layer.disable_gradient_checkpointing() + for block in self.double_blocks + self.single_blocks: + block.disable_gradient_checkpointing() + + print("Chroma: Gradient checkpointing disabled.") + + def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) # TODO: need to add toggle to omit this from schnell but that's not a priority distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) @@ -619,10 +636,7 @@ class Chroma(Flux): timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) - if requires_grad: - input_vec = input_vec.requires_grad_(True) - mod_vectors = self.distilled_guidance_layer(input_vec) - return mod_vectors + return input_vec def forward( self, @@ -637,7 +651,7 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, attn_padding: int = 1, - mod_vectors: Tensor | None = None, + input_vec: Tensor | None = None, ) -> Tensor: # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" @@ -651,7 +665,7 @@ class Chroma(Flux): img = self.img_in(img) txt = self.txt_in(txt) - if mod_vectors is None: + if input_vec is None: # TODO: # need to fix grad accumulation issue here for now it's in no grad mode # besides, i don't want to wash out the PFP that's trained on this model weights anyway @@ -659,14 +673,18 @@ class Chroma(Flux): # alternatively doing forward pass for every block manually is doable but slow # custom backward probably be better with torch.no_grad(): - # kohya-ss: I'm not sure why requires_grad is set to True here - mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True) + input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) + # kohya-ss: I'm not sure why requires_grad is set to True here + input_vec.requires_grad = True + mod_vectors = self.distilled_guidance_layer(input_vec) + else: + mod_vectors = self.distilled_guidance_layer(input_vec) mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) # calculate text length for each batch instead of masking txt_emb_len = txt.shape[1] - txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, ) + txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, ) txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len) max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch diff --git a/library/flux_models.py b/library/flux_models.py index 328ad481..6f889755 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -930,6 +930,9 @@ class Flux(nn.Module): self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) + def get_model_type(self) -> str: + return "flux" + @property def device(self): return next(self.parameters()).device @@ -1018,6 +1021,7 @@ class Flux(nn.Module): block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, + input_vec: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1169,7 +1173,7 @@ class ControlNetFlux(nn.Module): nn.SiLU(), nn.Conv2d(16, 16, 3, padding=1, stride=2), nn.SiLU(), - zero_module(nn.Conv2d(16, 16, 3, padding=1)) + zero_module(nn.Conv2d(16, 16, 3, padding=1)), ) @property @@ -1320,174 +1324,3 @@ class ControlNetFlux(nn.Module): controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,) return controlnet_block_samples, controlnet_single_block_samples - - -""" -class FluxUpper(nn.Module): - "" - Transformer model for flow matching on sequences. - "" - - def __init__(self, params: FluxParams): - super().__init__() - - self.params = params - self.in_channels = params.in_channels - self.out_channels = self.in_channels - if params.hidden_size % params.num_heads != 0: - raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") - pe_dim = params.hidden_size // params.num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) - self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) - - self.double_blocks = nn.ModuleList( - [ - DoubleStreamBlock( - self.hidden_size, - self.num_heads, - mlp_ratio=params.mlp_ratio, - qkv_bias=params.qkv_bias, - ) - for _ in range(params.depth) - ] - ) - - self.gradient_checkpointing = False - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing = True - - self.time_in.enable_gradient_checkpointing() - self.vector_in.enable_gradient_checkpointing() - if self.guidance_in.__class__ != nn.Identity: - self.guidance_in.enable_gradient_checkpointing() - - for block in self.double_blocks: - block.enable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing enabled.") - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing = False - - self.time_in.disable_gradient_checkpointing() - self.vector_in.disable_gradient_checkpointing() - if self.guidance_in.__class__ != nn.Identity: - self.guidance_in.disable_gradient_checkpointing() - - for block in self.double_blocks: - block.disable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing disabled.") - - def forward( - self, - img: Tensor, - img_ids: Tensor, - txt: Tensor, - txt_ids: Tensor, - timesteps: Tensor, - y: Tensor, - guidance: Tensor | None = None, - txt_attention_mask: Tensor | None = None, - ) -> Tensor: - if img.ndim != 3 or txt.ndim != 3: - raise ValueError("Input img and txt tensors must have 3 dimensions.") - - # running on sequences img - img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) - if self.params.guidance_embed: - if guidance is None: - raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) - vec = vec + self.vector_in(y) - txt = self.txt_in(txt) - - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) - - for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - - return img, txt, vec, pe - - -class FluxLower(nn.Module): - "" - Transformer model for flow matching on sequences. - "" - - def __init__(self, params: FluxParams): - super().__init__() - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.out_channels = params.in_channels - - self.single_blocks = nn.ModuleList( - [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(params.depth_single_blocks) - ] - ) - - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) - - self.gradient_checkpointing = False - - @property - def device(self): - return next(self.parameters()).device - - @property - def dtype(self): - return next(self.parameters()).dtype - - def enable_gradient_checkpointing(self): - self.gradient_checkpointing = True - - for block in self.single_blocks: - block.enable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing enabled.") - - def disable_gradient_checkpointing(self): - self.gradient_checkpointing = False - - for block in self.single_blocks: - block.disable_gradient_checkpointing() - - print("FLUX: Gradient checkpointing disabled.") - - def forward( - self, - img: Tensor, - txt: Tensor, - vec: Tensor | None = None, - pe: Tensor | None = None, - txt_attention_mask: Tensor | None = None, - ) -> Tensor: - img = torch.cat((txt, img), 1) - for block in self.single_blocks: - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) - img = img[:, txt.shape[1] :, ...] - - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) - return img -""" diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 8392e559..f3eb8199 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,9 +154,8 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - # TODO refactor variable names - cfg_scale = prompt_dict.get("guidance_scale", 1.0) - emb_guidance_scale = prompt_dict.get("scale", 3.5) + emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5) + cfg_scale = prompt_dict.get("scale", 1.0) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -242,7 +241,7 @@ def sample_image_inference( dtype=weight_dtype, generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None, ) - timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True + timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None @@ -403,8 +402,8 @@ def denoise( y=torch.cat([neg_l_pooled, vec], dim=0), block_controlnet_hidden_states=block_samples, block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, + timesteps=t_vec.repeat(2), + guidance=guidance_vec.repeat(2), txt_attention_mask=nc_c_t5_attn_mask, ) neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) @@ -680,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + + parser.add_argument( + "--model_type", + type=str, + choices=["flux", "chroma"], + default="flux", + help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)", + ) diff --git a/library/flux_utils.py b/library/flux_utils.py index dda7c789..3f0a0d63 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -23,6 +23,7 @@ from library.utils import load_safetensors MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" +MODEL_VERSION_CHROMA = "chroma" def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: @@ -97,7 +98,7 @@ def load_flow_model( device: Union[str, torch.device], disable_mmap: bool = False, model_type: str = "flux", -) -> Tuple[str, bool, flux_models.Flux]: +) -> Tuple[bool, flux_models.Flux]: if model_type == "flux": is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL @@ -140,7 +141,7 @@ def load_flow_model( info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") - return model_type, is_schnell, model + return is_schnell, model elif model_type == "chroma": from . import chroma_models @@ -166,7 +167,7 @@ def load_flow_model( info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Chroma: {info}") is_schnell = False # Chroma is not schnell - return model_type, is_schnell, model + return is_schnell, model else: raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.") @@ -203,6 +204,42 @@ def load_controlnet( return controlnet +def dummy_clip_l() -> torch.nn.Module: + """ + Returns a dummy CLIP-L model with the output shape of (N, 77, 768). + """ + return DummyCLIPL() + + +class DummyTextModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.embeddings = torch.nn.Parameter(torch.zeros(1)) + + +class DummyCLIPL(torch.nn.Module): + def __init__(self): + super().__init__() + self.output_shape = (77, 1) # Note: The original code had (77, 768), but we use (77, 1) for the dummy output + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) # get dtype and device from this parameter + self.text_model = DummyTextModel() + + @property + def device(self): + return self.dummy_param.device + + @property + def dtype(self): + return self.dummy_param.dtype + + def forward(self, *args, **kwargs): + """ + Returns a dummy output with the shape of (N, 77, 768). + """ + batch_size = args[0].shape[0] if args else 1 + return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)} + + def load_clip_l( ckpt_path: Optional[str], dtype: torch.dtype, diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 8896c047..662a6b2e 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -60,6 +60,8 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc. # ARCH_SD3_UNKNOWN = "stable-diffusion-3" ARCH_FLUX_1_DEV = "flux-1-dev" +ARCH_FLUX_1_SCHNELL = "flux-1-schnell" +ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma ARCH_FLUX_1_UNKNOWN = "flux-1" ADAPTER_LORA = "lora" @@ -69,6 +71,7 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI" IMPL_DIFFUSERS = "diffusers" IMPL_FLUX = "https://github.com/black-forest-labs/flux" +IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma" PRED_TYPE_EPSILON = "epsilon" PRED_TYPE_V = "v" @@ -125,7 +128,7 @@ def build_metadata( flux: Optional[str] = None, ): """ - sd3: only supports "m", flux: only supports "dev" + sd3: only supports "m", flux: supports "dev", "schnell" or "chroma" """ # if state_dict is None, hash is not calculated @@ -144,6 +147,10 @@ def build_metadata( elif flux is not None: if flux == "dev": arch = ARCH_FLUX_1_DEV + elif flux == "schnell": + arch = ARCH_FLUX_1_SCHNELL + elif flux == "chroma": + arch = ARCH_FLUX_1_CHROMA else: arch = ARCH_FLUX_1_UNKNOWN elif v2: @@ -166,7 +173,10 @@ def build_metadata( if flux is not None: # Flux - impl = IMPL_FLUX + if flux == "chroma": + impl = IMPL_CHROMA + else: + impl = IMPL_FLUX elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI diff --git a/library/train_util.py b/library/train_util.py index 36d419fd..b09963fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3482,7 +3482,7 @@ def get_sai_model_spec( textual_inversion: bool, is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA sd3: str = None, - flux: str = None, + flux: str = None, # "dev", "schnell" or "chroma" ): timestamp = time.time() From 0b763ef1f17fc9117b630c3478c6ae02437ac07e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 20:53:06 +0900 Subject: [PATCH 07/15] feat: fix timestep for input_vec for Chroma --- flux_train_network.py | 4 +--- library/chroma_models.py | 36 ++++++++++++++++++++++++++++++------ library/flux_models.py | 3 +++ 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 1b61ac72..13e9ae2a 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -341,9 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # get modulation vectors for Chroma - input_vec = None - if self.model_type == "chroma": - input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz) + input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) diff --git a/library/chroma_models.py b/library/chroma_models.py index e5d3b547..b9c54db4 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -223,7 +223,10 @@ class DoubleStreamBlock(nn.Module): # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift) img_qkv = self.img_attn.qkv(img_modulated) + del img_modulated + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention @@ -232,7 +235,10 @@ class DoubleStreamBlock(nn.Module): # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift) txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention: we split the batch into each element @@ -263,9 +269,11 @@ class DoubleStreamBlock(nn.Module): img_v[i] = None attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D) + del q, k, v img_attn_i = attn[:, :img_len, :] txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device) txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :] + del attn txt_attn.append(txt_attn_i) img_attn.append(img_attn_i) @@ -279,27 +287,31 @@ class DoubleStreamBlock(nn.Module): # attn = attention(q, k, v, pe=pe, attn_mask=mask) # txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - # calculate the img bloks + # calculate the img blocks # replaced with compiled fn # img = img + img_mod1.gate * self.img_attn.proj(img_attn) # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn)) + del img_attn, img_mod1 img = self.modulation_gate_fn( img, img_mod2.gate, self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)), ) + del img_mod2 - # calculate the txt bloks + # calculate the txt blocks # replaced with compiled fn # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn)) + del txt_attn, txt_mod1 txt = self.modulation_gate_fn( txt, txt_mod2.gate, self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)), ) + del txt_mod2 return img, txt @@ -374,8 +386,10 @@ class SingleStreamBlock(nn.Module): # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + del x_mod q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + del qkv q, k = self.norm(q, k, v) # # compute attention @@ -399,12 +413,15 @@ class SingleStreamBlock(nn.Module): attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device) attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed + del attn_trimmed attn.append(attn_i) attn = torch.cat(attn, dim=0) # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) + del attn, mlp # replaced with compiled fn # return x + mod.gate * output return self.modulation_gate_fn(x, mod.gate, output) @@ -625,6 +642,7 @@ class Chroma(Flux): print("Chroma: Gradient checkpointing disabled.") def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + # print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}") distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) # TODO: need to add toggle to omit this from schnell but that's not a priority distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4) @@ -656,6 +674,7 @@ class Chroma(Flux): # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" # ) + # print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}") # print(f"timesteps: {timesteps}, guidance: {guidance}") if img.ndim != 3 or txt.ndim != 3: @@ -687,6 +706,7 @@ class Chroma(Flux): txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, ) txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len) max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch + # print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}") # trim txt embedding to the text length txt = txt[:, :max_txt_len, :] @@ -700,23 +720,27 @@ class Chroma(Flux): self.offloader_double.wait_for_block(i) # the guidance replaced by FFN output - img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"] - txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] + img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin") + txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin") double_mod = [img_mod, txt_mod] + del img_mod, txt_mod img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len) + del double_mod if self.blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, i) img = torch.cat((img, txt), 1) + del txt for i, block in enumerate(self.single_blocks): if self.blocks_to_swap: self.offloader_single.wait_for_block(i) - single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] + single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin") img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len) + del single_mod if self.blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, i) diff --git a/library/flux_models.py b/library/flux_models.py index 6f889755..2a2fe5f8 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1009,6 +1009,9 @@ class Flux(nn.Module): self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) + def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + return None # FLUX.1 does not use input_vec, but Chroma does. + def forward( self, img: Tensor, From 77a160d8867422ffdf7be34d8879fe29e05a8040 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 21:25:43 +0900 Subject: [PATCH 08/15] fix: skip LoRA creation for None text encoders (CLIP-L for Chroma) --- networks/lora_flux.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 0b30f1b8..ddc91608 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -892,6 +892,9 @@ class LoRANetwork(torch.nn.Module): skipped_te = [] for i, text_encoder in enumerate(text_encoders): index = i + if text_encoder is None: + logger.info(f"Text Encoder {index+1} is None, skipping LoRA creation for this encoder.") + continue if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False break From 32f06012a750737699bc4872173c9e960f000980 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Mon, 21 Jul 2025 21:48:06 +0900 Subject: [PATCH 09/15] doc: update flux train document and add about breaking changes in sample generation prompts --- README-ja.md | 13 +- README.md | 12 +- docs/flux_train_network.md | 686 ++++++++++++++++++++----------------- 3 files changed, 396 insertions(+), 315 deletions(-) diff --git a/README-ja.md b/README-ja.md index 60249f61..c310dd8a 100644 --- a/README-ja.md +++ b/README-ja.md @@ -155,11 +155,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。 - * `--n` Negative prompt up to the next option. - * `--w` Specifies the width of the generated image. - * `--h` Specifies the height of the generated image. - * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * `--s` Specifies the number of steps in the generation. + * `--n` ネガティブプロンプト(次のオプションまで) + * `--w` 生成画像の幅を指定 + * `--h` 生成画像の高さを指定 + * `--d` 生成画像のシード値を指定 + * `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください + * `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください + * `--s` 生成時のステップ数を指定 `( )` や `[ ]` などの重みづけも動作します。 diff --git a/README.md b/README.md index 3ef16593..9ba1cbfc 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,13 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed ### Recent Updates +Jul XX, 2025: +- **Breaking Change**: For FLUX.1 and Chroma training, the CFG scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details. + +- Support for [Chroma](https://huggingface.co/lodestones/Chroma) has been added in PR [#2157](https://github.com/kohya-ss/sd-scripts/pull/2157). Thank you to lodestones for the high-quality model. + - Chroma is a new model based on FLUX.1 schnell. In this repository, `flux_train_network.py` is used for training LoRAs for Chroma with `--model_type chroma`. + - Please refer to the [FLUX.1 LoRA training documentation](./docs/flux_train_network.md) for more details. + Jul 21, 2025: - Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions. - Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details. @@ -1367,9 +1374,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. - * `--l` Specifies the CFG scale of the generated image. - * In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility. - * `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models. + * `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG. + * `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models. * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md index 2b7ff749..f324b959 100644 --- a/docs/flux_train_network.md +++ b/docs/flux_train_network.md @@ -4,6 +4,13 @@ Status: reviewed This document explains how to train LoRA models for the FLUX.1 model using `flux_train_network.py` included in the `sd-scripts` repository. +
+日本語 + +このドキュメントでは、`sd-scripts`リポジトリに含まれる`flux_train_network.py`を使用して、FLUX.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 + +
+ ## 1. Introduction / はじめに `flux_train_network.py` trains additional networks such as LoRA on the FLUX.1 model, which uses a transformer-based architecture different from Stable Diffusion. Two text encoders, CLIP-L and T5-XXL, and a dedicated AutoEncoder are used. @@ -15,21 +22,73 @@ This guide assumes you know the basics of LoRA training. For common options see * The repository is cloned and the Python environment is ready. * A training dataset is prepared. See the dataset configuration guide. +
+日本語 + +`flux_train_network.py`は、FLUX.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。FLUX.1はStable Diffusionとは異なるアーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。 + +このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。 + +**前提条件:** + +* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 +* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) + +
+ ## 2. Differences from `train_network.py` / `train_network.py` との違い -`flux_train_network.py` is based on `train_network.py` but adapted for FLUX.1. Main differences include required arguments for the FLUX.1 model, CLIP-L, T5-XXL and AE, different model structure, and some incompatible options from Stable Diffusion. +`flux_train_network.py` is based on `train_network.py` but adapted for FLUX.1. Main differences include: + +* **Target model:** FLUX.1 model (dev or schnell version). +* **Model structure:** Unlike Stable Diffusion, FLUX.1 uses a Transformer-based architecture with two text encoders (CLIP-L and T5-XXL) and a dedicated AutoEncoder (AE) instead of VAE. +* **Required arguments:** Additional arguments for FLUX.1 model, CLIP-L, T5-XXL, and AE model files. +* **Incompatible options:** Some Stable Diffusion-specific arguments (e.g., `--v2`, `--clip_skip`, `--max_token_length`) are not used in FLUX.1 training. +* **FLUX.1-specific arguments:** Additional arguments for FLUX.1-specific training parameters like timestep sampling and guidance scale. + +
+日本語 + +`flux_train_network.py`は`train_network.py`をベースに、FLUX.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。 + +* **対象モデル:** FLUX.1モデル(dev版またはschnell版)を対象とします。 +* **モデル構造:** Stable Diffusionとは異なり、FLUX.1はTransformerベースのアーキテクチャを持ちます。Text EncoderとしてCLIP-LとT5-XXLの二つを使用し、VAEの代わりに専用のAutoEncoder (AE) を使用します。 +* **必須の引数:** FLUX.1モデル、CLIP-L、T5-XXL、AEの各モデルファイルを指定する引数が追加されています。 +* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)はFLUX.1の学習では使用されません。 +* **FLUX.1特有の引数:** タイムステップのサンプリング方法やガイダンススケールなど、FLUX.1特有の学習パラメータを指定する引数が追加されています。 + +
## 3. Preparation / 準備 Before starting training you need: 1. **Training script:** `flux_train_network.py` -2. **FLUX.1 model file** and text encoder files (`clip_l`, `t5xxl`) and AE file. -3. **Dataset definition file (.toml)** such as `my_flux_dataset_config.toml`. +2. **FLUX.1 model file:** Base FLUX.1 model `.safetensors` file (e.g., `flux1-dev.safetensors`). +3. **Text Encoder model files:** + - CLIP-L model `.safetensors` file (e.g., `clip_l.safetensors`) + - T5-XXL model `.safetensors` file (e.g., `t5xxl.safetensors`) +4. **AutoEncoder model file:** FLUX.1-compatible AE model `.safetensors` file (e.g., `ae.safetensors`). +5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration (e.g., `my_flux_dataset_config.toml`). + +
+日本語 + +学習を開始する前に、以下のファイルが必要です。 + +1. **学習スクリプト:** `flux_train_network.py` +2. **FLUX.1モデルファイル:** 学習のベースとなるFLUX.1モデルの`.safetensors`ファイル(例: `flux1-dev.safetensors`)。 +3. **Text Encoderモデルファイル:** + - CLIP-Lモデルの`.safetensors`ファイル。例として`clip_l.safetensors`を使用します。 + - T5-XXLモデルの`.safetensors`ファイル。例として`t5xxl.safetensors`を使用します。 +4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。 +5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。例として`my_flux_dataset_config.toml`を使用します。 + +
## 4. Running the Training / 学習の実行 -Run `flux_train_network.py` from the terminal with FLUX.1 specific arguments. Example: +Run `flux_train_network.py` from the terminal with FLUX.1 specific arguments. Here's a basic command example: ```bash accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ @@ -54,369 +113,318 @@ accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ --gradient_checkpointing \ --guidance_scale=1.0 \ --timestep_sampling="flux_shift" \ + --model_prediction_type="raw" \ --blocks_to_swap=18 \ --cache_text_encoder_outputs \ --cache_latents ``` +### Training Chroma Models + +If you want to train a Chroma model, specify `--model_type=chroma`. Chroma does not use CLIP-L, so the `--clip_l` argument is not needed. T5XXL and AE are same as FLUX.1. The command would look like this: + +```bash +accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ + --pretrained_model_name_or_path="" \ + --model_type=chroma \ + --t5xxl="" \ + --ae="" \ + --dataset_config="my_flux_dataset_config.toml" \ + --output_dir="" \ + --output_name="my_chroma_lora" \ + --guidance_scale=0.0 \ + --timestep_sampling="sigmoid" \ + --apply_t5_attn_mask \ + ... +``` + +Note that for Chroma models, `--guidance_scale=0.0` is required to disable guidance scale, and `--apply_t5_attn_mask` is needed to apply attention masks for T5XXL Text Encoder. + +
+日本語 + +学習は、ターミナルから`flux_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、FLUX.1特有の引数を指定する必要があります。 + +コマンドラインの例は英語のドキュメントを参照してください。 + +#### Chromaモデルの学習 + +Chromaモデルを学習したい場合は、`--model_type=chroma`を指定します。ChromaはCLIP-Lを使用しないため、`--clip_l`引数は不要です。T5XXLとAEはFLUX.1と同様です。 + +コマンドラインの例は英語のドキュメントを参照してください。 + +
+ ### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 -The script adds FLUX.1 specific arguments such as guidance scale, timestep sampling, block swapping, and options for training CLIP-L and T5-XXL LoRA modules. Some Stable Diffusion options like `--v2` and `--clip_skip` are not used. +The script adds FLUX.1 specific arguments. For common arguments (like `--output_dir`, `--output_name`, `--network_module`, etc.), see the [`train_network.py` guide](train_network.md). + +#### Model-related [Required] + +* `--pretrained_model_name_or_path=""` **[Required]** + - Specifies the path to the base FLUX.1 or Chroma model `.safetensors` file. Diffusers format directories are not currently supported. +* `--model_type=` + - Specifies the type of base model for training. Choose from `flux` or `chroma`. Default is `flux`. +* `--clip_l=""` **[Required when flux is selected]** + - Specifies the path to the CLIP-L Text Encoder model `.safetensors` file. Not needed when `--model_type=chroma`. +* `--t5xxl=""` **[Required]** + - Specifies the path to the T5-XXL Text Encoder model `.safetensors` file. +* `--ae=""` **[Required]** + - Specifies the path to the FLUX.1-compatible AutoEncoder model `.safetensors` file. + +#### FLUX.1 Training Parameters + +* `--guidance_scale=` + - FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `1.0` to disable guidance scale. Default is `3.5`, so be sure to specify this. Usually ignored for schnell version. + - Chroma requires `--guidance_scale=0.0` to disable guidance scale. +* `--timestep_sampling=` + - Specifies the sampling method for timesteps (noise levels) during training. Choose from `sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`. Default is `sigma`. Recommended is `flux_shift`. For Chroma models, `sigmoid` is recommended. +* `--sigmoid_scale=` + - Scale factor when `timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default and recommended value is `1.0`. +* `--model_prediction_type=` + - Specifies what the model predicts. Choose from `raw` (use prediction as-is), `additive` (add to noise input), `sigma_scaled` (apply sigma scaling). Default is `sigma_scaled`. Recommended is `raw`. +* `--discrete_flow_shift=` + - Specifies the shift value for the scheduler used in Flow Matching. Default is `3.0`. This value is ignored when `timestep_sampling` is set to other than `shift`. + +#### Memory/Speed Related + +* `--fp8_base` + - Enables training in FP8 format for FLUX.1, CLIP-L, and T5-XXL. This can significantly reduce VRAM usage, but the training results may vary. +* `--blocks_to_swap=` **[Experimental Feature]** + - Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`. + - Cannot be used with `--cpu_offload_checkpointing`. +* `--cache_text_encoder_outputs` + - Caches the outputs of CLIP-L and T5-XXL. This reduces memory usage. +* `--cache_latents`, `--cache_latents_to_disk` + - Caches the outputs of AE. Similar functionality to [sdxl_train_network.py](sdxl_train_network.md). + +#### Incompatible/Deprecated Arguments + +* `--v2`, `--v_parameterization`, `--clip_skip`: These are Stable Diffusion-specific arguments and are not used in FLUX.1 training. +* `--max_token_length`: This is an argument for Stable Diffusion v1/v2. For FLUX.1, use `--t5xxl_max_token_length`. +* `--split_mode`: Deprecated argument. Use `--blocks_to_swap` instead. + +
+日本語 + +[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のFLUX.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。 + +コマンドラインの例と詳細な引数の説明は英語のドキュメントを参照してください。 + +
### 4.2. Starting Training / 学習の開始 -Training begins once you run the command with the required options. Log checking is the same as in `train_network.py`. +Training begins once you run the command with the required options. Log checking is the same as in [`train_network.py`](train_network.md#32-starting-the-training--学習の開始). + +
+日本語 + +必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 + +
## 5. Using the Trained Model / 学習済みモデルの利用 After training, a LoRA model file is saved in `output_dir` and can be used in inference environments supporting FLUX.1 (e.g. ComfyUI + Flux nodes). -## 6. Others / その他 - -Additional notes on VRAM optimization, training options, multi-resolution datasets, block selection and text encoder LoRA are provided in the Japanese section. -
日本語 - - -# `flux_train_network.py` を用いたFLUX.1モデルのLoRA学習ガイド - -このドキュメントでは、`sd-scripts`リポジトリに含まれる`flux_train_network.py`を使用して、FLUX.1モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。 - -## 1. はじめに - -`flux_train_network.py`は、FLUX.1モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。FLUX.1はStable Diffusionとは異なるアーキテクチャを持つ画像生成モデルであり、このスクリプトを使用することで、特定のキャラクターや画風を再現するLoRAモデルを作成できます。 - -このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sdxl_train_network.py`](sdxl_train_network.md) と同様のものがあるため、そちらも参考にしてください。 - -**前提条件:** - -* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。 -* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](link/to/dataset/config/doc)を参照してください) - -## 2. `train_network.py` との違い - -`flux_train_network.py`は`train_network.py`をベースに、FLUX.1モデルに対応するための変更が加えられています。主な違いは以下の通りです。 - -* **対象モデル:** FLUX.1モデル(dev版またはschnell版)を対象とします。 -* **モデル構造:** Stable Diffusionとは異なり、FLUX.1はTransformerベースのアーキテクチャを持ちます。Text EncoderとしてCLIP-LとT5-XXLの二つを使用し、VAEの代わりに専用のAutoEncoder (AE) を使用します。 -* **必須の引数:** FLUX.1モデル、CLIP-L、T5-XXL、AEの各モデルファイルを指定する引数が追加されています。 -* **一部引数の非互換性:** Stable Diffusion向けの引数の一部(例: `--v2`, `--clip_skip`, `--max_token_length`)はFLUX.1の学習では使用されません。 -* **FLUX.1特有の引数:** タイムステップのサンプリング方法やガイダンススケールなど、FLUX.1特有の学習パラメータを指定する引数が追加されています。 - -## 3. 準備 - -学習を開始する前に、以下のファイルが必要です。 - -1. **学習スクリプト:** `flux_train_network.py` -2. **FLUX.1モデルファイル:** 学習のベースとなるFLUX.1モデルの`.safetensors`ファイル(例: `flux1-dev.safetensors`)。 -3. **Text Encoderモデルファイル:** - * CLIP-Lモデルの`.safetensors`ファイル。例として`clip_l.safetensors`を使用します。 - * T5-XXLモデルの`.safetensors`ファイル。例として`t5xxl.safetensors`を使用します。 -4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。 -5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。 - - * 例として`my_flux_dataset_config.toml`を使用します。 - -## 4. 学習の実行 - -学習は、ターミナルから`flux_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、FLUX.1特有の引数を指定する必要があります。 - -以下に、基本的なコマンドライン実行例を示します。 - -```bash -accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py - --pretrained_model_name_or_path="" - --clip_l="" - --t5xxl="" - --ae="" - --dataset_config="my_flux_dataset_config.toml" - --output_dir="" - --output_name="my_flux_lora" - --save_model_as=safetensors - --network_module=networks.lora_flux - --network_dim=16 - --network_alpha=1 - --learning_rate=1e-4 - --optimizer_type="AdamW8bit" - --lr_scheduler="constant" - --sdpa - --max_train_epochs=10 - --save_every_n_epochs=1 - --mixed_precision="fp16" - --gradient_checkpointing - --guidance_scale=1.0 - --timestep_sampling="flux_shift" - --blocks_to_swap=18 - --cache_text_encoder_outputs - --cache_latents -``` - -※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。 - -### 4.1. 主要なコマンドライン引数の解説(`train_network.py`からの追加・変更点) - -[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のFLUX.1特有の引数を指定します。共通の引数(`--output_dir`, `--output_name`, `--network_module`, `--network_dim`, `--network_alpha`, `--learning_rate`など)については、上記ガイドを参照してください。 - -#### モデル関連 [必須] - -* `--pretrained_model_name_or_path=""` **[必須]** - * 学習のベースとなるFLUX.1モデル(dev版またはschnell版)の`.safetensors`ファイルのパスを指定します。Diffusers形式のディレクトリは現在サポートされていません。 -* `--clip_l=""` **[必須]** - * CLIP-L Text Encoderモデルの`.safetensors`ファイルのパスを指定します。 -* `--t5xxl=""` **[必須]** - * T5-XXL Text Encoderモデルの`.safetensors`ファイルのパスを指定します。 -* `--ae=""` **[必須]** - * FLUX.1に対応するAutoEncoderモデルの`.safetensors`ファイルのパスを指定します。 - -#### FLUX.1 学習パラメータ - -* `--guidance_scale=` - * FLUX.1 dev版は特定のガイダンススケール値で蒸留されていますが、学習時には `1.0` を指定してガイダンススケールを無効化します。デフォルトは`3.5`ですので、必ず指定してください。schnell版では通常無視されます。 -* `--timestep_sampling=` - * 学習時に使用するタイムステップ(ノイズレベル)のサンプリング方法を指定します。`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift` から選択します。デフォルトは `sigma` です。推奨は `flux_shift` です。 -* `--sigmoid_scale=` - * `timestep_sampling` に `sigmoid` または `shift`, `flux_shift` を指定した場合のスケール係数です。デフォルトおよび推奨値は`1.0`です。 -* `--model_prediction_type=` - * モデルが何を予測するかを指定します。`raw` (予測値をそのまま使用), `additive` (ノイズ入力に加算), `sigma_scaled` (シグマスケーリングを適用) から選択します。デフォルトは `sigma_scaled` です。推奨は `raw` です。 -* `--discrete_flow_shift=` - * Flow Matchingで使用されるスケジューラのシフト値を指定します。デフォルトは`3.0`です。`timestep_sampling`に`flux_shift`を指定した場合は、この値は無視されます。 - -#### メモリ・速度関連 - -* `--blocks_to_swap=` **[実験的機能]** - * VRAM使用量を削減するために、モデルの一部(Transformerブロック)をCPUとGPU間でスワップする設定です。スワップするブロック数を整数で指定します(例: `18`)。値を大きくするとVRAM使用量は減りますが、学習速度は低下します。GPUのVRAM容量に応じて調整してください。`gradient_checkpointing`と併用可能です。 - * `--cpu_offload_checkpointing`とは併用できません。 -* `--cache_text_encoder_outputs` - * CLIP-LおよびT5-XXLの出力をキャッシュします。これにより、メモリ使用量が削減されます。 -* `--cache_latents`, `--cache_latents_to_disk` - * AEの出力をキャッシュします。[sdxl_train_network.py](sdxl_train_network.md)と同様の機能です。 - -#### 非互換・非推奨の引数 - -* `--v2`, `--v_parameterization`, `--clip_skip`: Stable Diffusion特有の引数のため、FLUX.1学習では使用されません。 -* `--max_token_length`: Stable Diffusion v1/v2向けの引数です。FLUX.1では`--t5xxl_max_token_length`を使用してください。 -* `--split_mode`: 非推奨の引数です。代わりに`--blocks_to_swap`を使用してください。 - -### 4.2. 学習の開始 - -必要な引数を設定し、コマンドを実行すると学習が開始されます。基本的な流れやログの確認方法は[`train_network.py`のガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。 - -## 5. 学習済みモデルの利用 - 学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_flux_lora.safetensors`)が保存されます。このファイルは、FLUX.1モデルに対応した推論環境(例: ComfyUI + ComfyUI-FluxNodes)で使用できます。 -## 6. その他 +
-`flux_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python flux_train_network.py --help`) を参照してください。 +## 6. Advanced Settings / 高度な設定 -# FLUX.1 LoRA学習の補足説明 +### 6.1. VRAM Usage Optimization / VRAM使用量の最適化 -以下は、以上の基本的なFLUX.1 LoRAの学習手順を補足するものです。より詳細な設定オプションなどについて説明します。 +FLUX.1 is a relatively large model, so GPUs without sufficient VRAM require optimization. Here are settings to reduce VRAM usage (with `--fp8_base`): -## 1. VRAM使用量の最適化 +#### Recommended Settings by GPU Memory -FLUX.1モデルは比較的大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。以下に、VRAM使用量を削減するための設定を紹介します。 +| GPU Memory | Recommended Settings | +|------------|---------------------| +| 24GB VRAM | Basic settings work fine (batch size 2) | +| 16GB VRAM | Set batch size to 1 and use `--blocks_to_swap` | +| 12GB VRAM | Use `--blocks_to_swap 16` and 8bit AdamW | +| 10GB VRAM | Use `--blocks_to_swap 22`, recommend fp8 format for T5XXL | +| 8GB VRAM | Use `--blocks_to_swap 28`, recommend fp8 format for T5XXL | -### 1.1 メモリ使用量別の推奨設定 +#### Key VRAM Reduction Options -| GPUメモリ | 推奨設定 | -|----------|----------| -| 24GB VRAM | 基本設定で問題なく動作します(バッチサイズ2) | -| 16GB VRAM | バッチサイズ1に設定し、`--blocks_to_swap`を使用 | -| 12GB VRAM | `--blocks_to_swap 16`と8bit AdamWを使用 | -| 10GB VRAM | `--blocks_to_swap 22`を使用、T5XXLはfp8形式を推奨 | -| 8GB VRAM | `--blocks_to_swap 28`を使用、T5XXLはfp8形式を推奨 | +- **`--fp8_base`**: Enables training in FP8 format. -### 1.2 主要なVRAM削減オプション +- **`--blocks_to_swap `**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. FLUX.1 supports up to 35 blocks for swapping. -- **`--blocks_to_swap <数値>`**: - CPUとGPU間でブロックをスワップしてVRAM使用量を削減します。数値が大きいほど多くのブロックをスワップし、より多くのVRAMを節約できますが、学習速度は低下します。FLUX.1では最大35ブロックまでスワップ可能です。 +- **`--cpu_offload_checkpointing`**: Offloads gradient checkpoints to CPU. Can reduce VRAM usage by up to 1GB but decreases training speed by about 15%. Cannot be used with `--blocks_to_swap`. Chroma models do not support this option. -- **`--cpu_offload_checkpointing`**: - 勾配チェックポイントをCPUにオフロードします。これにより最大1GBのVRAM使用量を削減できますが、学習速度は約15%低下します。`--blocks_to_swap`とは併用できません。 - -- **`--cache_text_encoder_outputs` / `--cache_text_encoder_outputs_to_disk`**: - CLIP-LとT5-XXLの出力をキャッシュします。これによりメモリ使用量を削減できます。 - -- **`--cache_latents` / `--cache_latents_to_disk`**: - AEの出力をキャッシュします。メモリ使用量を削減できます。 - -- **Adafactor オプティマイザの使用**: - 8bit AdamWよりもVRAM使用量を削減できます。以下の設定を使用してください: +- **Using Adafactor optimizer**: Can reduce VRAM usage more than 8bit AdamW: ``` --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0 ``` -- **T5XXLのfp8形式の使用**: - 10GB未満のVRAMを持つGPUでは、T5XXLのfp8形式チェックポイントの使用を推奨します。[comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders)から`t5xxl_fp8_e4m3fn.safetensors`をダウンロードできます(`scaled`なしで使用してください)。 +- **Using T5XXL fp8 format**: For GPUs with less than 10GB VRAM, using fp8 format T5XXL checkpoints is recommended. Download `t5xxl_fp8_e4m3fn.safetensors` from [comfyanonymous/flux_text_encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) (use without `scaled`). -- **FP8/FP16 混合学習 [実験的機能]**: - `--fp8_base_unet` オプションを指定すると、FLUX.1モデル本体をFP8形式で学習し、Text Encoder (CLIP-L/T5XXL) をBF16/FP16形式で学習できます。これにより、さらにVRAM使用量を削減できる可能性があります。このオプションを指定すると、`--fp8_base` オプションも自動的に有効になります。 +- **FP8/FP16 Mixed Training [Experimental]**: Specify `--fp8_base_unet` to train the FLUX.1 model in FP8 format while training Text Encoders (CLIP-L/T5XXL) in BF16/FP16 format. This can further reduce VRAM usage. -- **`pytorch-optimizer` の利用**: - `pytorch-optimizer` ライブラリに含まれる様々なオプティマイザを使用できます。`requirements.txt` に追加されているため、別途インストールは不要です。 - 例えば、CAME オプティマイザを使用する場合は以下のように指定します。 - ```bash - --optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01" - -## 2. FLUX.1 LoRA学習の重要な設定オプション +
+日本語 -FLUX.1の学習には多くの未知の点があり、いくつかの設定は引数で指定できます。以下に重要な引数とその説明を示します。 +FLUX.1モデルは比較的大きなモデルであるため、十分なVRAMを持たないGPUでは工夫が必要です。VRAM使用量を削減するための設定の詳細は英語のドキュメントを参照してください。 -### 2.1 タイムステップのサンプリング方法 +主要なVRAM削減オプション: +- `--fp8_base`: FP8形式での学習を有効化 +- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ +- `--cpu_offload_checkpointing`: 勾配チェックポイントをCPUにオフロード +- Adafactorオプティマイザの使用 +- T5XXLのfp8形式の使用 +- FP8/FP16混合学習(実験的機能) -`--timestep_sampling`オプションで、タイムステップ(0-1)のサンプリング方法を指定できます: +
-- `sigma`:SD3と同様のシグマベース -- `uniform`:一様ランダム -- `sigmoid`:正規分布乱数のシグモイド(x-flux、AI-toolkitなどと同様) -- `shift`:正規分布乱数のシグモイド値をシフト -- `flux_shift`:解像度に応じて正規分布乱数のシグモイド値をシフト(FLUX.1 dev推論と同様)。この設定では`--discrete_flow_shift`は無視されます。 +### 6.2. Important FLUX.1 LoRA Training Settings / FLUX.1 LoRA学習の重要な設定 +FLUX.1 training has many unknowns, and several settings can be specified with arguments: -#### タイムステップ分布の可視化 +#### Timestep Sampling Methods -`--timestep_sampling`, `--sigmoid_scale`, `--discrete_flow_shift` の組み合わせによって、学習中にサンプリングされるタイムステップの分布が変化します。以下にいくつかの例を示します。 +The `--timestep_sampling` option specifies how timesteps (0-1) are sampled: -* `--timestep_sampling shift` と `--discrete_flow_shift` の効果 (`--sigmoid_scale` はデフォルトの1.0): - ![Figure_2](https://github.com/user-attachments/assets/d9de42f9-f17d-40da-b88d-d964402569c6) +- `sigma`: Sigma-based like SD3 +- `uniform`: Uniform random +- `sigmoid`: Sigmoid of normal distribution random (similar to x-flux, AI-toolkit) +- `shift`: Sigmoid value of normal distribution random with shift. The `--discrete_flow_shift` setting is used to shift the sigmoid value. +- `flux_shift`: Shift sigmoid value of normal distribution random according to resolution (similar to FLUX.1 dev inference). -* `--timestep_sampling sigmoid` と `--timestep_sampling uniform` の比較 (`--discrete_flow_shift` は無視される): - ![Figure_3](https://github.com/user-attachments/assets/27029009-1f5d-4dc0-bb24-13d02ac4fdad) +`--discrete_flow_shift` only applies when `--timestep_sampling` is set to `shift`. -* `--timestep_sampling sigmoid` と `--sigmoid_scale` の効果 (`--discrete_flow_shift` は無視される): - ![Figure_4](https://github.com/user-attachments/assets/08a2267c-e47e-48b7-826e-f9a080787cdc) +#### Model Prediction Processing -#### AI Toolkit 設定との比較 +The `--model_prediction_type` option specifies how to interpret and process model predictions: -[Ostris氏のAI Toolkit](https://github.com/ostris/ai-toolkit) で使用されている設定は、概ね以下のオプションに相当すると考えられます。 -``` ---timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 1.0 -``` +- `raw`: Use as-is (similar to x-flux) **[Recommended]** +- `additive`: Add to noise input +- `sigma_scaled`: Apply sigma scaling (similar to SD3) -### 2.2 モデル予測の処理方法 +#### Recommended Settings -`--model_prediction_type`オプションで、モデルの予測をどのように解釈し処理するかを指定できます: - -- `raw`:そのまま使用(x-fluxと同様)【推奨】 -- `additive`:ノイズ入力に加算 -- `sigma_scaled`:シグマスケーリングを適用(SD3と同様) - -### 2.3 推奨設定 - -実験の結果、以下の設定が良好に動作することが確認されています: +Based on experiments, the following settings work well: ``` --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` -ガイダンススケールについて:FLUX.1 dev版は特定のガイダンススケール値で蒸留されていますが、学習時には`--guidance_scale 1.0`を指定してガイダンススケールを無効化することを推奨します。 +**About Guidance Scale**: FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `--guidance_scale 1.0` to disable guidance scale. +`--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 0.0` is recommended for Chroma models. -### 2.4 T5 Attention Mask の適用 +
+日本語 -`--apply_t5_attn_mask` オプションを指定すると、T5XXL Text Encoder の学習および推論時に Attention Mask が適用されます。 +FLUX.1の学習には多くの未知の点があり、いくつかの設定は引数で指定できます。詳細な説明とコマンドラインの例は英語のドキュメントを参照してください。 -Attention Maskに対応した推論環境が限られるため、このオプションは推奨されません。 +主要な設定オプション: +- タイムステップのサンプリング方法(`--timestep_sampling`) +- モデル予測の処理方法(`--model_prediction_type`) +- 推奨設定の組み合わせ -### 2.5 IP ノイズガンマ +
-`--ip_noise_gamma` および `--ip_noise_gamma_random_strength` オプションを使用することで、学習時に Input Perturbation ノイズのガンマ値を調整できます。詳細は Stable Diffusion 3 の学習オプションを参照してください。 +### 6.3. Layer-specific Rank Configuration / 各層に対するランク指定 -### 2.6 LoRA-GGPO サポート +You can specify different ranks (network_dim) for each layer of FLUX.1. This allows you to emphasize or disable LoRA effects for specific layers. -LoRA-GGPO (Gradient Group Proportion Optimizer) を使用できます。これは LoRA の学習を安定化させるための手法です。以下の `network_args` を指定して有効化します。ハイパーパラメータ (`ggpo_sigma`, `ggpo_beta`) は調整が必要です。 +Specify the following network_args to set ranks for each layer. Setting 0 disables LoRA for that layer: -```bash ---network_args "ggpo_sigma=0.03" "ggpo_beta=0.01" -``` -TOMLファイルで指定する場合: -```toml -network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"] -``` - -### 2.7 Q/K/V 射影層の分割 [実験的機能] - -`--network_args "split_qkv=True"` を指定することで、Attention層内の Q/K/V (および SingleStreamBlock の Text) 射影層を個別に分割し、それぞれに LoRA を適用できます。 - -**技術的詳細:** -FLUX.1 の元々の実装では、Q/K/V (および Text) の射影層は一つに結合されています。ここに LoRA を適用すると、一つの大きな LoRA モジュールが適用されます。一方、Diffusers の実装ではこれらの射影層は分離されており、それぞれに小さな LoRA モジュールが適用されます。このオプションは後者の挙動を模倣します。 -保存される LoRA モデルの互換性は維持されますが、内部的には分割された LoRA の重みを結合して保存するため、ゼロ要素が多くなりモデルサイズが大きくなる可能性があります。`convert_flux_lora.py` スクリプトを使用して Diffusers (AI-Toolkit) 形式に変換すると、サイズが削減されます。 - -## 3. 各層に対するランク指定 - -FLUX.1の各層に対して異なるランク(network_dim)を指定できます。これにより、特定の層に対してLoRAの効果を強調したり、無効化したりできます。 - -以下のnetwork_argsを指定することで、各層のランクを指定できます。0を指定するとその層にはLoRAが適用されません。 - -| network_args | 対象レイヤー | +| network_args | Target Layer | |--------------|--------------| -| img_attn_dim | DoubleStreamBlockのimg_attn | -| txt_attn_dim | DoubleStreamBlockのtxt_attn | -| img_mlp_dim | DoubleStreamBlockのimg_mlp | -| txt_mlp_dim | DoubleStreamBlockのtxt_mlp | -| img_mod_dim | DoubleStreamBlockのimg_mod | -| txt_mod_dim | DoubleStreamBlockのtxt_mod | -| single_dim | SingleStreamBlockのlinear1とlinear2 | -| single_mod_dim | SingleStreamBlockのmodulation | +| img_attn_dim | DoubleStreamBlock img_attn | +| txt_attn_dim | DoubleStreamBlock txt_attn | +| img_mlp_dim | DoubleStreamBlock img_mlp | +| txt_mlp_dim | DoubleStreamBlock txt_mlp | +| img_mod_dim | DoubleStreamBlock img_mod | +| txt_mod_dim | DoubleStreamBlock txt_mod | +| single_dim | SingleStreamBlock linear1 and linear2 | +| single_mod_dim | SingleStreamBlock modulation | -使用例: +Example usage: ``` --network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2" "img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2" ``` -さらに、FLUXの条件付けレイヤーにLoRAを適用するには、network_argsに`in_dims`を指定します。5つの数値をカンマ区切りのリストとして指定する必要があります。 +To apply LoRA to FLUX conditioning layers, specify `in_dims` in network_args as a comma-separated list of 5 numbers: -例: ``` --network_args "in_dims=[4,2,2,2,4]" ``` -各数値は、`img_in`、`time_in`、`vector_in`、`guidance_in`、`txt_in`に対応します。上記の例では、すべての条件付けレイヤーにLoRAを適用し、`img_in`と`txt_in`のランクを4、その他のランクを2に設定しています。 +Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The example above applies LoRA to all conditioning layers with ranks of 4 for `img_in` and `txt_in`, and ranks of 2 for others. -0を指定するとそのレイヤーにはLoRAが適用されません。例えば、`[4,0,0,0,4]`は`img_in`と`txt_in`にのみLoRAを適用します。 +
+日本語 -## 4. 学習するブロックの指定 +FLUX.1の各層に対して異なるランク(network_dim)を指定できます。これにより、特定の層に対してLoRAの効果を強調したり、無効化したりできます。 -FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_single_block_indices`を指定することで、学習するブロックを指定できます。インデックスは0ベースです。省略した場合のデフォルトはすべてのブロックを学習することです。 +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 -インデックスは、`0,1,5,8`のような整数のリストや、`0,1,4-5,7`のような整数の範囲として指定します。 -- double blocksの数は19なので、有効な範囲は0-18です -- single blocksの数は38なので、有効な範囲は0-37です -- `all`を指定するとすべてのブロックを学習します -- `none`を指定するとブロックを学習しません +
-使用例: +### 6.4. Block Selection for Training / 学習するブロックの指定 + +You can specify which blocks to train using `train_double_block_indices` and `train_single_block_indices` in network_args. Indices are 0-based. Default is to train all blocks if omitted. + +Specify indices as integer lists like `0,1,5,8` or integer ranges like `0,1,4-5,7`: +- Double blocks: 19 blocks, valid range 0-18 +- Single blocks: 38 blocks, valid range 0-37 +- Specify `all` to train all blocks +- Specify `none` to skip training blocks + +Example usage: ``` --network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37" ``` -または: +Or: ``` --network_args "train_double_block_indices=none" "train_single_block_indices=10-15" ``` -`train_double_block_indices`または`train_single_block_indices`のどちらか一方だけを指定した場合、もう一方は通常通り学習されます。 +
+日本語 -## 5. Text Encoder LoRAのサポート +FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_single_block_indices`を指定することで、学習するブロックを指定できます。 + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 6.5. Text Encoder LoRA Support / Text Encoder LoRAのサポート + +FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA: + +- To train only FLUX.1: specify `--network_train_unet_only` +- To train FLUX.1 and CLIP-L: omit `--network_train_unet_only` +- To train FLUX.1, CLIP-L, and T5XXL: omit `--network_train_unet_only` and add `--network_args "train_t5xxl=True"` + +You can specify individual learning rates for CLIP-L and T5XXL with `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5` sets the first value for CLIP-L and the second for T5XXL. Specifying one value uses the same learning rate for both. If `--text_encoder_lr` is not specified, the default `--learning_rate` is used for both. + +
+日本語 FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポートしています。 -- FLUX.1のみをトレーニングする場合は、`--network_train_unet_only`を指定します -- FLUX.1とCLIP-Lをトレーニングする場合は、`--network_train_unet_only`を省略します -- FLUX.1、CLIP-L、T5XXLすべてをトレーニングする場合は、`--network_train_unet_only`を省略し、`--network_args "train_t5xxl=True"`を追加します +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 -CLIP-LとT5XXLの学習率は、`--text_encoder_lr`で個別に指定できます。例えば、`--text_encoder_lr 1e-4 1e-5`とすると、最初の値はCLIP-Lの学習率、2番目の値はT5XXLの学習率になります。1つだけ指定すると、CLIP-LとT5XXLの学習率は同じになります。`--text_encoder_lr`を指定しない場合、デフォルトの学習率`--learning_rate`が両方に使用されます。 +
-## 6. マルチ解像度トレーニング +### 6.6. Multi-Resolution Training / マルチ解像度トレーニング -データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。 +You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution. -設定ファイルの例: +Configuration file example: ```toml [general] -# 共通設定をここで定義 +# Common settings flip_aug = true color_aug = false keep_tokens_separator= "|||" @@ -425,85 +433,151 @@ caption_tag_dropout_rate = 0 caption_extension = ".txt" [[datasets]] -# 最初の解像度の設定 +# First resolution settings batch_size = 2 enable_bucket = true resolution = [1024, 1024] [[datasets.subsets]] - image_dir = "画像ディレクトリへのパス" + image_dir = "path/to/image/directory" num_repeats = 1 [[datasets]] -# 2番目の解像度の設定 +# Second resolution settings batch_size = 3 enable_bucket = true resolution = [768, 768] [[datasets.subsets]] - image_dir = "画像ディレクトリへのパス" - num_repeats = 1 - -[[datasets]] -# 3番目の解像度の設定 -batch_size = 4 -enable_bucket = true -resolution = [512, 512] - - [[datasets.subsets]] - image_dir = "画像ディレクトリへのパス" + image_dir = "path/to/image/directory" num_repeats = 1 ``` -各解像度セクションの`[[datasets.subsets]]`部分は、データセットディレクトリを定義します。各解像度に対して同じディレクトリを指定してください。 +
+日本語 -## 7. 検証 (Validation) +データセット設定ファイルで複数の解像度を定義できます。各解像度に対して異なるバッチサイズを指定することができます。 -学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。 +設定ファイルの例は英語のドキュメントを参照してください。 -検証を設定するには、データセット設定 TOML ファイルに `[validation]` セクションを追加します。設定方法は学習データセットと同様ですが、`num_repeats` は通常 1 に設定します。 +
+ +### 6.7. Validation / 検証 + +You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. + +To set up validation, add a `[validation]` section to your dataset configuration TOML file. Configuration is similar to training datasets, but `num_repeats` is usually set to 1. ```toml -# ... (学習データセットの設定) ... +# ... (training dataset configuration) ... [validation] batch_size = 1 enable_bucket = true -resolution = [1024, 1024] # 検証に使用する解像度 +resolution = [1024, 1024] # Resolution for validation [[validation.subsets]] - image_dir = "検証用画像ディレクトリへのパス" + image_dir = "path/to/validation/images" num_repeats = 1 caption_extension = ".txt" - # ... 他の検証データセット固有の設定 ... + # ... other validation dataset settings ... ``` -**注意点:** +**Notes:** -* 検証損失の計算は、固定されたタイムステップサンプリングと乱数シードで行われます。これにより、ランダム性による損失の変動を抑え、より安定した評価が可能になります。 -* 現在のところ、`--blocks_to_swap` オプションを使用している場合、または Schedule-Free オプティマイザ (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`) を使用している場合は、検証損失はサポートされていません。 +* Validation loss calculation uses fixed timestep sampling and random seeds to reduce loss variation due to randomness for more stable evaluation. +* Currently, validation loss is not supported when using `--blocks_to_swap` or Schedule-Free optimizers (`AdamWScheduleFree`, `RAdamScheduleFree`, `ProdigyScheduleFree`). -## 8. データセット関連の追加オプション +
+日本語 -### 8.1 リサイズ時の補間方法指定 +学習中に検証データセットを使用して損失 (Validation Loss) を計算し、モデルの汎化性能を評価できます。 -データセットの画像を学習解像度にリサイズする際の補間方法を指定できます。データセット設定 TOML ファイルの `[[datasets]]` セクションまたは `[general]` セクションで `interpolation_type` を指定します。 +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 -利用可能な値: `bicubic` (デフォルト), `bilinear`, `lanczos`, `nearest`, `area` +
+ +## 7. Additional Options / 追加オプション + +### 7.1. Other FLUX.1-specific Options / その他のFLUX.1特有のオプション + +- **T5 Attention Mask Application**: Specify `--apply_t5_attn_mask` to apply attention masks during T5XXL Text Encoder training and inference. Not recommended due to limited inference environment support. **For Chroma models, this option is required.** + +- **IP Noise Gamma**: Use `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` to adjust Input Perturbation noise gamma values during training. See Stable Diffusion 3 training options for details. + +- **LoRA-GGPO Support**: Use LoRA-GGPO (Gradient Group Proportion Optimizer) to stabilize LoRA training: + ```bash + --network_args "ggpo_sigma=0.03" "ggpo_beta=0.01" + ``` + +- **Q/K/V Projection Layer Splitting [Experimental]**: Specify `--network_args "split_qkv=True"` to individually split and apply LoRA to Q/K/V (and SingleStreamBlock Text) projection layers within Attention layers. + +
+日本語 + +その他のFLUX.1特有のオプション: +- T5 Attention Maskの適用(Chromaモデルでは必須) +- IPノイズガンマ +- LoRA-GGPOサポート +- Q/K/V射影層の分割(実験的機能) + +詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 + +
+ +### 7.2. Dataset-related Additional Options / データセット関連の追加オプション + +#### Interpolation Method for Resizing + +You can specify the interpolation method when resizing dataset images to training resolution. Specify `interpolation_type` in the `[[datasets]]` or `[general]` section of the dataset configuration TOML file. + +Available values: `bicubic` (default), `bilinear`, `lanczos`, `nearest`, `area` ```toml [[datasets]] resolution = [1024, 1024] enable_bucket = true -interpolation_type = "lanczos" # 例: Lanczos補間を使用 +interpolation_type = "lanczos" # Example: Use Lanczos interpolation # ... ``` -## 9. 関連ツール +
+日本語 -`flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています。 +データセットの画像を学習解像度にリサイズする際の補間方法を指定できます。 -* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出します。 -* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など、他の形式に変換します。Q/K/V分割オプションで学習した場合、このスクリプトで変換するとモデルサイズを削減できます。 -* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージします。 -* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するためのシンプルな推論スクリプトです。 +設定方法とオプションの詳細は英語のドキュメントを参照してください。 + +
+ +## 8. Related Tools / 関連ツール + +Several related scripts are provided for models trained with `flux_train_network.py` and to assist with the training process: + +* **`networks/flux_extract_lora.py`**: Extracts LoRA models from the difference between trained and base models. +* **`convert_flux_lora.py`**: Converts trained LoRA models to other formats like Diffusers (AI-Toolkit) format. When trained with Q/K/V split option, converting with this script can reduce model size. +* **`networks/flux_merge_lora.py`**: Merges trained LoRA models into FLUX.1 base models. +* **`flux_minimal_inference.py`**: Simple inference script for generating images with trained LoRA models. You can specify `flux` or `chroma` with the `--model_type` argument. + +
+日本語 + +`flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています: + +* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出 +* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など他の形式に変換 +* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージ +* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するシンプルな推論スクリプト + +
+ +## 9. Others / その他 + +`flux_train_network.py` includes many features common with `train_network.py`, such as sample image generation (`--sample_prompts`, etc.) and detailed optimizer settings. For these features, refer to the [`train_network.py` guide](train_network.md#5-other-features--その他の機能) or the script help (`python flux_train_network.py --help`). + +
+日本語 + +`flux_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python flux_train_network.py --help`) を参照してください。 + +
From c28e7a47c3bd3c4efc81404bf4dadba2b41d4fe4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 26 Jul 2025 19:35:42 +0900 Subject: [PATCH 10/15] feat: add regex-based rank and learning rate configuration for FLUX.1 LoRA --- docs/flux_train_network.md | 49 ++++++++- networks/lora_flux.py | 197 +++++++++++++++++++++++++++---------- train_network.py | 2 +- 3 files changed, 194 insertions(+), 54 deletions(-) diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md index f324b959..1b584180 100644 --- a/docs/flux_train_network.md +++ b/docs/flux_train_network.md @@ -398,7 +398,50 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s -### 6.5. Text Encoder LoRA Support / Text Encoder LoRAのサポート + + + +### 6.4. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 + +You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer. + +These settings are specified via the `network_args` argument. + +* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`. + * Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * This sets the rank to 4 for modules whose names contain `single` and contain `_modulation`, and to 8 for modules containing `img_attn`. +* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`. + * Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * This sets the learning rate to `1e-3` for modules whose names contain `single_blocks` followed by a digit (`0` to `9`) or `10`, and to `2e-3` for modules whose names contain `double_blocks`. + +**Notes:** + +* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings. +* If a module name matches multiple patterns, the setting from the last matching pattern in the string will be applied. +* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`). + +
+日本語 + +正規表現を用いて、LoRAのモジュールごとにランク(dim)や学習率を指定することができます。これにより、層ごとの指定よりも柔軟できめ細やかな制御が可能になります。 + +これらの設定は `network_args` 引数で指定します。 + +* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * この例では、名前に `single` で始まり `_modulation` を含むモジュールのランクを4に、`img_attn` を含むモジュールのランクを8に設定します。 +* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * この例では、名前が `single_blocks` で始まり、後に数字(`0`から`9`)または`10`が続くモジュールの学習率を `1e-3` に、`double_blocks` を含むモジュールの学習率を `2e-3` に設定します。 +**注意点:** + +* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。 +* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。 +* これらの設定は、ブロック指定(`train_double_block_indices`, `train_single_block_indices`)が適用された後に行われます。 + +
+ +### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA: @@ -417,7 +460,7 @@ FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポート -### 6.6. Multi-Resolution Training / マルチ解像度トレーニング +### 6.7. Multi-Resolution Training / マルチ解像度トレーニング You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution. @@ -462,7 +505,7 @@ resolution = [768, 768] -### 6.7. Validation / 検証 +### 6.8. Validation / 検証 You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ddc91608..320bc463 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -156,11 +156,19 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + if ( + self.training + and self.ggpo_sigma is not None + and self.ggpo_beta is not None + and self.combined_weight_norms is not None + and self.grad_norms is not None + ): with torch.no_grad(): - perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + ( + self.ggpo_beta * (self.grad_norms**2) + ) perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) - perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) perturbation.mul_(perturbation_scale_factor) perturbation_output = x @ perturbation.T # Result: (batch × n) return org_forwarded + (self.multiplier * scale * lx) + perturbation_output @@ -197,24 +205,24 @@ class LoRAModule(torch.nn.Module): # Choose a reasonable sample size n_rows = org_module_weight.shape[0] sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller - + # Sample random indices across all rows indices = torch.randperm(n_rows)[:sample_size] - + # Convert to a supported data type first, then index # Use float32 for indexing operations weights_float32 = org_module_weight.to(dtype=torch.float32) sampled_weights = weights_float32[indices].to(device=self.device) - + # Calculate sampled norms sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) - + # Store the mean norm as our estimate self.org_weight_norm_estimate = sampled_norms.mean() - + # Optional: store standard deviation for confidence intervals self.org_weight_norm_std = sampled_norms.std() - + # Free memory del sampled_weights, weights_float32 @@ -223,37 +231,36 @@ class LoRAModule(torch.nn.Module): # Calculate the true norm (this will be slow but it's just for validation) true_norms = [] chunk_size = 1024 # Process in chunks to avoid OOM - + for i in range(0, org_module_weight.shape[0], chunk_size): end_idx = min(i + chunk_size, org_module_weight.shape[0]) chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) chunk_norms = torch.norm(chunk, dim=1, keepdim=True) true_norms.append(chunk_norms.cpu()) del chunk - + true_norms = torch.cat(true_norms, dim=0) true_mean_norm = true_norms.mean().item() - + # Compare with our estimate estimated_norm = self.org_weight_norm_estimate.item() - + # Calculate error metrics absolute_error = abs(true_mean_norm - estimated_norm) relative_error = absolute_error / true_mean_norm * 100 # as percentage - + if verbose: logger.info(f"True mean norm: {true_mean_norm:.6f}") logger.info(f"Estimated norm: {estimated_norm:.6f}") logger.info(f"Absolute error: {absolute_error:.6f}") logger.info(f"Relative error: {relative_error:.2f}%") - - return { - 'true_mean_norm': true_mean_norm, - 'estimated_norm': estimated_norm, - 'absolute_error': absolute_error, - 'relative_error': relative_error - } + return { + "true_mean_norm": true_mean_norm, + "estimated_norm": estimated_norm, + "absolute_error": absolute_error, + "relative_error": relative_error, + } @torch.no_grad() def update_norms(self): @@ -261,7 +268,7 @@ class LoRAModule(torch.nn.Module): if self.ggpo_beta is None or self.ggpo_sigma is None: return - # only update norms when we are training + # only update norms when we are training if self.training is False: return @@ -269,8 +276,9 @@ class LoRAModule(torch.nn.Module): module_weights.mul(self.scale) self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) - self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) + - torch.sum(module_weights**2, dim=1, keepdim=True)) + self.combined_weight_norms = torch.sqrt( + (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True) + ) @torch.no_grad() def update_grad_norms(self): @@ -293,7 +301,6 @@ class LoRAModule(torch.nn.Module): approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) - @property def device(self): return next(self.parameters()).device @@ -564,7 +571,6 @@ def create_network( if ggpo_sigma is not None: ggpo_sigma = float(ggpo_sigma) - # train T5XXL train_t5xxl = kwargs.get("train_t5xxl", False) if train_t5xxl is not None: @@ -575,6 +581,42 @@ def create_network( if verbose is not None: verbose = True if verbose == "True" else False + # regex-specific learning rates + def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]: + """ + Parse a string of key-value pairs separated by commas. + """ + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + # parse regular expression based learning rates + network_reg_lrs = kwargs.get("network_reg_lrs", None) + if network_reg_lrs is not None: + reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False) + else: + reg_lrs = None + + # regex-specific dimensions (ranks) + network_reg_dims = kwargs.get("network_reg_dims", None) + if network_reg_dims is not None: + reg_dims = parse_kv_pairs(network_reg_dims, is_int=True) + else: + reg_dims = None + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -594,8 +636,10 @@ def create_network( in_dims=in_dims, train_double_block_indices=train_double_block_indices, train_single_block_indices=train_single_block_indices, + reg_dims=reg_dims, ggpo_beta=ggpo_beta, ggpo_sigma=ggpo_sigma, + reg_lrs=reg_lrs, verbose=verbose, ) @@ -613,7 +657,6 @@ def create_network( # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): - # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -644,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh if train_t5xxl is None: train_t5xxl = False - # # split qkv - # double_qkv_rank = None - # single_qkv_rank = None - # rank = None - # for lora_name, dim in modules_dim.items(): - # if "double" in lora_name and "qkv" in lora_name: - # double_qkv_rank = dim - # elif "single" in lora_name and "linear1" in lora_name: - # single_qkv_rank = dim - # elif rank is None: - # rank = dim - # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: - # break - # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( - # single_qkv_rank is not None and single_qkv_rank != rank - # ) split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -708,8 +735,10 @@ class LoRANetwork(torch.nn.Module): in_dims: Optional[List[int]] = None, train_double_block_indices: Optional[List[bool]] = None, train_single_block_indices: Optional[List[bool]] = None, + reg_dims: Optional[Dict[str, int]] = None, ggpo_beta: Optional[float] = None, ggpo_sigma: Optional[float] = None, + reg_lrs: Optional[Dict[str, float]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -730,6 +759,8 @@ class LoRANetwork(torch.nn.Module): self.in_dims = in_dims self.train_double_block_indices = train_double_block_indices self.train_single_block_indices = train_single_block_indices + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -757,7 +788,6 @@ class LoRANetwork(torch.nn.Module): if self.train_blocks is not None: logger.info(f"train {self.train_blocks} blocks only") - if train_t5xxl: logger.info(f"train T5XXL as well") @@ -803,8 +833,16 @@ class LoRANetwork(torch.nn.Module): if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする + elif self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.search(reg, lora_name): + dim = d + alpha = self.alpha + logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") + break + + # 通常、すべて対象とする + if dim is None: if is_linear or is_conv2d_1x1: dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha @@ -979,7 +1017,6 @@ class LoRANetwork(torch.nn.Module): combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None - def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -1166,17 +1203,77 @@ class LoRANetwork(torch.nn.Module): all_params = [] lr_descriptions = [] + reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else [] + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} + # regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...} + reg_groups = {} + for lora in loras: + # check if this lora matches any regex learning rate + matched_reg_lr = None + for i, (regex_str, reg_lr) in enumerate(reg_lrs_list): + try: + if re.search(regex_str, lora.lora_name): + matched_reg_lr = (i, reg_lr) + logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}") + break + except re.error: + # regex error should have been caught during parsing, but just in case + continue + for name, param in lora.named_parameters(): - if loraplus_ratio is not None and "lora_up" in name: - param_groups["plus"][f"{lora.lora_name}.{name}"] = param + param_key = f"{lora.lora_name}.{name}" + is_plus = loraplus_ratio is not None and "lora_up" in name + + if matched_reg_lr is not None: + # use regex-specific learning rate + reg_idx, reg_lr = matched_reg_lr + group_key = f"reg_lr_{reg_idx}" + if group_key not in reg_groups: + reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr} + + if is_plus: + reg_groups[group_key]["plus"][param_key] = param + else: + reg_groups[group_key]["lora"][param_key] = param else: - param_groups["lora"][f"{lora.lora_name}.{name}"] = param + # use default learning rate + if is_plus: + param_groups["plus"][param_key] = param + else: + param_groups["lora"][param_key] = param params = [] descriptions = [] + + # process regex-specific groups first (higher priority) + for group_key in sorted(reg_groups.keys()): + group = reg_groups[group_key] + reg_lr = group["lr"] + + for param_type in ["lora", "plus"]: + if len(group[param_type]) == 0: + continue + + param_data = {"params": group[param_type].values()} + + if param_type == "plus" and loraplus_ratio is not None: + param_data["lr"] = reg_lr * loraplus_ratio + else: + param_data["lr"] = reg_lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + continue + + params.append(param_data) + desc = f"reg_lr_{group_key.split('_')[-1]}" + if param_type == "plus": + desc += " plus" + descriptions.append(desc) + + # process default groups for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} diff --git a/train_network.py b/train_network.py index 6073c4c3..7861e740 100644 --- a/train_network.py +++ b/train_network.py @@ -645,7 +645,7 @@ class NetworkTrainer: net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: - key, value = net_arg.split("=") + key, value = net_arg.split("=", 1) net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') From af14eab6d7f81493d23a7b961e01084f52eb5adf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 26 Jul 2025 19:37:15 +0900 Subject: [PATCH 11/15] doc: update section number for regex-based rank and learning rate configuration in FLUX.1 LoRA guide --- docs/flux_train_network.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md index 1b584180..647e87c9 100644 --- a/docs/flux_train_network.md +++ b/docs/flux_train_network.md @@ -401,7 +401,7 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s -### 6.4. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 +### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer. From 6c8973c2da72fe9112729bdac9fc1ca21e06945c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 28 Jul 2025 22:08:02 +0900 Subject: [PATCH 12/15] doc: add reference link for input vector gradient requirement in Chroma class --- library/chroma_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/chroma_models.py b/library/chroma_models.py index b9c54db4..0c93f526 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -695,6 +695,7 @@ class Chroma(Flux): input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) # kohya-ss: I'm not sure why requires_grad is set to True here + # original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217 input_vec.requires_grad = True mod_vectors = self.distilled_guidance_layer(input_vec) else: From 450630c6bda18026c6017df088a8d73f89f67a60 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 29 Jul 2025 20:32:24 +0900 Subject: [PATCH 13/15] fix: create network from weights not working --- networks/lora_flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 320bc463..e9ad5f68 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -841,8 +841,8 @@ class LoRANetwork(torch.nn.Module): logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") break - # 通常、すべて対象とする - if dim is None: + # if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default) + if dim is None and modules_dim is None: if is_linear or is_conv2d_1x1: dim = default_dim if default_dim is not None else self.lora_dim alpha = self.alpha From 96feb61c0a3d42f3526c09131090a33d2e5d8f23 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Jul 2025 21:34:49 +0900 Subject: [PATCH 14/15] feat: implement modulation vector extraction for Chroma and update related methods --- flux_minimal_inference.py | 3 +++ flux_train_network.py | 15 ++++++++------- library/chroma_models.py | 28 ++++++++++------------------ library/flux_models.py | 6 +++--- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 86e8e1b1..d5f2d8d9 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -113,6 +113,8 @@ def denoise( y_input = b_vec + mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0]) + pred = model( img=b_img, img_ids=b_img_ids, @@ -122,6 +124,7 @@ def denoise( timesteps=t_vec, guidance=guidance_vec, txt_attention_mask=b_t5_attn_mask, + mod_vectors=mod_vectors, ) # classifier free guidance diff --git a/flux_train_network.py b/flux_train_network.py index 13e9ae2a..2d9ab248 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -341,7 +341,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) # get modulation vectors for Chroma - input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) + with accelerator.autocast(), torch.no_grad(): + mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz) if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) @@ -350,15 +351,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) - if input_vec is not None: - input_vec.requires_grad_(True) + if mod_vectors is not None: + mod_vectors.requires_grad_(True) # Predict the noise residual l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, input_vec): + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -371,7 +372,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps / 1000, guidance=guidance_vec, txt_attention_mask=t5_attn_mask, - input_vec=input_vec, + mod_vectors=mod_vectors, ) return model_pred @@ -384,7 +385,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps, guidance_vec=guidance_vec, t5_attn_mask=t5_attn_mask, - input_vec=input_vec, + mod_vectors=mod_vectors, ) # unpack latents @@ -416,7 +417,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): timesteps=timesteps[diff_output_pr_indices], guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None, t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None, - input_vec=input_vec[diff_output_pr_indices] if input_vec is not None else None, + mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None, ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step diff --git a/library/chroma_models.py b/library/chroma_models.py index 0c93f526..d5ac1f39 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -641,7 +641,10 @@ class Chroma(Flux): print("Chroma: Gradient checkpointing disabled.") - def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + # We extract this logic from forward to clarify the propagation of the gradients + # original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195 + # print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}") distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4) # TODO: need to add toggle to omit this from schnell but that's not a priority @@ -654,7 +657,9 @@ class Chroma(Flux): timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1) # then and only then we could concatenate it together input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1) - return input_vec + + mod_vectors = self.distilled_guidance_layer(input_vec) + return mod_vectors def forward( self, @@ -669,7 +674,7 @@ class Chroma(Flux): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, attn_padding: int = 1, - input_vec: Tensor | None = None, + mod_vectors: Tensor | None = None, ) -> Tensor: # print( # f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}" @@ -684,22 +689,9 @@ class Chroma(Flux): img = self.img_in(img) txt = self.txt_in(txt) - if input_vec is None: - # TODO: - # need to fix grad accumulation issue here for now it's in no grad mode - # besides, i don't want to wash out the PFP that's trained on this model weights anyway - # the fan out operation here is deleting the backward graph - # alternatively doing forward pass for every block manually is doable but slow - # custom backward probably be better + if mod_vectors is None: # fallback to the original logic with torch.no_grad(): - input_vec = self.get_input_vec(timesteps, guidance, img.shape[0]) - - # kohya-ss: I'm not sure why requires_grad is set to True here - # original code: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L217 - input_vec.requires_grad = True - mod_vectors = self.distilled_guidance_layer(input_vec) - else: - mod_vectors = self.distilled_guidance_layer(input_vec) + mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0]) mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) # calculate text length for each batch instead of masking diff --git a/library/flux_models.py b/library/flux_models.py index 63d699d4..d2d7e06c 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1009,8 +1009,8 @@ class Flux(nn.Module): self.offloader_double.prepare_block_devices_before_forward(self.double_blocks) self.offloader_single.prepare_block_devices_before_forward(self.single_blocks) - def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: - return None # FLUX.1 does not use input_vec, but Chroma does. + def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor: + return None # FLUX.1 does not use mod_vectors, but Chroma does. def forward( self, @@ -1024,7 +1024,7 @@ class Flux(nn.Module): block_controlnet_single_hidden_states=None, guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, - input_vec: Tensor | None = None, + mod_vectors: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") From 250f0eb9b051784f6f18bb223ea88860119a0172 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 30 Jul 2025 22:08:51 +0900 Subject: [PATCH 15/15] doc: update README and training guide with breaking changes for CFG scale and model download instructions --- README.md | 4 +-- docs/flux_train_network.md | 57 ++++++++++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9ba1cbfc..724bd3d8 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,8 @@ If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed ### Recent Updates -Jul XX, 2025: -- **Breaking Change**: For FLUX.1 and Chroma training, the CFG scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details. +Jul 30, 2025: +- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details. - Support for [Chroma](https://huggingface.co/lodestones/Chroma) has been added in PR [#2157](https://github.com/kohya-ss/sd-scripts/pull/2157). Thank you to lodestones for the high-quality model. - Chroma is a new model based on FLUX.1 schnell. In this repository, `flux_train_network.py` is used for training LoRAs for Chroma with `--model_type chroma`. diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md index 647e87c9..2bf3bfb2 100644 --- a/docs/flux_train_network.md +++ b/docs/flux_train_network.md @@ -71,6 +71,21 @@ Before starting training you need: 4. **AutoEncoder model file:** FLUX.1-compatible AE model `.safetensors` file (e.g., `ae.safetensors`). 5. **Dataset definition file (.toml):** TOML format file describing training dataset configuration (e.g., `my_flux_dataset_config.toml`). +### Downloading Required Models + +To train FLUX.1 models, you need to download the following model files: + +- **DiT, AE**: Download from the [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) repository. Use `flux1-dev.safetensors` and `ae.safetensors`. The weights in the subfolder are in Diffusers format and cannot be used. +- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: Download from the [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) repository. Please use `t5xxl_fp16.safetensors` for T5-XXL. Thanks to ComfyUI for providing these models. + +To train Chroma models, you need to download the Chroma model file from the following repository: + +- **Chroma Base**: Download from the [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) repository. Use `Chroma.safetensors`. + +We have tested Chroma training with the weights from the [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) repository. + +AE and T5-XXL models are same as FLUX.1, so you can use the same files. CLIP-L model is not used for Chroma training, so you can omit the `--clip_l` argument. +
日本語 @@ -84,6 +99,21 @@ Before starting training you need: 4. **AutoEncoderモデルファイル:** FLUX.1に対応するAEモデルの`.safetensors`ファイル。例として`ae.safetensors`を使用します。 5. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](link/to/dataset/config/doc)を参照してください)。例として`my_flux_dataset_config.toml`を使用します。 +**必要なモデルのダウンロード** + +FLUX.1モデルを学習するためには、以下のモデルファイルをダウンロードする必要があります。 + +- **DiT, AE**: [black-forest-labs/FLUX.1 dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) リポジトリからダウンロードします。`flux1-dev.safetensors`と`ae.safetensors`を使用してください。サブフォルダ内の重みはDiffusers形式であり、使用できません。 +- **Text Encoder 1 (T5-XXL), Text Encoder 2 (CLIP-L)**: [ComfyUI FLUX Text Encoders](https://huggingface.co/comfyanonymous/flux_text_encoders) リポジトリからダウンロードします。T5-XXLには`t5xxl_fp16.safetensors`を使用してください。これらのモデルを提供いただいたComfyUIに感謝します。 + +Chromaモデルを学習する場合は、以下のリポジトリからChromaモデルファイルをダウンロードする必要があります。 + +- **Chroma Base**: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base) リポジトリからダウンロードします。`Chroma.safetensors`を使用してください。 + +Chromaの学習のテストは [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) リポジトリの重みを使用して行いました。 + +AEとT5-XXLモデルはFLUX.1と同じものを使用できるため、同じファイルを使用します。CLIP-LモデルはChroma学習では使用されないため、`--clip_l`引数は省略できます。 +
## 4. Running the Training / 学習の実行 @@ -140,6 +170,12 @@ accelerate launch --num_cpu_threads_per_process 1 flux_train_network.py \ Note that for Chroma models, `--guidance_scale=0.0` is required to disable guidance scale, and `--apply_t5_attn_mask` is needed to apply attention masks for T5XXL Text Encoder. +The sample image generation during training requires specifying a negative prompt. Also, set `--g 0` to disable embedded guidance scale and `--l 4.0` to set the CFG scale. For example: + +``` +Japanese shrine in the summer forest. --n low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors --w 512 --h 512 --d 1 --l 4.0 --g 0.0 --s 20 +``` +
日本語 @@ -153,6 +189,8 @@ Chromaモデルを学習したい場合は、`--model_type=chroma`を指定し コマンドラインの例は英語のドキュメントを参照してください。 +学習中のサンプル画像生成には、ネガティブプロンプトを指定してください。また `--g 0` を指定して埋め込みガイダンススケールを無効化し、`--l 4.0` を指定してCFGスケールを設定します。 +
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説 @@ -314,9 +352,12 @@ Based on experiments, the following settings work well: --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 ``` -**About Guidance Scale**: FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `--guidance_scale 1.0` to disable guidance scale. +For Chroma models, the following settings are recommended: +``` +--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 0.0 +``` -`--timestep_sampling sigmoid --model_prediction_type raw --guidance_scale 0.0` is recommended for Chroma models. +**About Guidance Scale**: FLUX.1 dev version is distilled with specific guidance scale values, but for training, specify `--guidance_scale 1.0` to disable guidance scale. Chroma requires `--guidance_scale 0.0` to disable guidance scale because it is not distilled.
日本語 @@ -396,9 +437,6 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s 詳細な設定方法とコマンドラインの例は英語のドキュメントを参照してください。 -
- - ### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 @@ -607,10 +645,11 @@ Several related scripts are provided for models trained with `flux_train_network `flux_train_network.py` で学習したモデルや、学習プロセスに役立つ関連スクリプトが提供されています: -* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出 -* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など他の形式に変換 -* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージ -* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するシンプルな推論スクリプト +* **`networks/flux_extract_lora.py`**: 学習済みモデルとベースモデルの差分から LoRA モデルを抽出。 +* **`convert_flux_lora.py`**: 学習した LoRA モデルを Diffusers (AI-Toolkit) 形式など他の形式に変換。 +* **`networks/flux_merge_lora.py`**: 学習した LoRA モデルを FLUX.1 ベースモデルにマージ。 +* **`flux_minimal_inference.py`**: 学習した LoRA モデルを適用して画像を生成するシンプルな推論スクリプト。 + `--model_type` 引数で `flux` または `chroma` を指定できます。