Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 14:34:23 +00:00)

Compare commits: 5 commits, e1bf05ec67...feat-flux-
| Author | SHA1 | Date |
|---|---|---|
| | 5585347a52 | |
| | b9357c662d | |
| | 08dc0858b5 | |
| | f9710863ca | |
| | f6137a7175 | |
```diff
@@ -18,6 +18,7 @@ from safetensors.torch import load_file
 from library import device_utils
 from library.device_utils import init_ipex, get_preferred_device
+from library.safetensors_utils import MemoryEfficientSafeOpen
 from networks import oft_flux

 init_ipex()
```
```diff
@@ -325,7 +326,8 @@ def generate_image(
     # generate image
     logger.info("Generating image...")
-    model = model.to(device)
+    if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0):
+        model = model.to(device)
     if steps is None:
         steps = 4 if is_schnell else 50
```
```diff
@@ -411,12 +413,16 @@ if __name__ == "__main__":
     parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
     parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
     parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
+    parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model")
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
     parser.add_argument("--guidance", type=float, default=3.5)
     parser.add_argument("--negative_prompt", type=str, default=None)
     parser.add_argument("--cfg_scale", type=float, default=1.0)
     parser.add_argument("--offload", action="store_true", help="Offload to CPU")
+    parser.add_argument(
+        "--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage"
+    )
     parser.add_argument(
         "--lora_weights",
         type=str,
```
```diff
@@ -442,6 +448,8 @@ if __name__ == "__main__":
     t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
     ae_dtype = str_to_dtype(args.ae_dtype, dtype)
     flux_dtype = str_to_dtype(args.flux_dtype, dtype)
+    if args.fp8_scaled and flux_dtype.itemsize == 1:
+        raise ValueError("fp8_scaled is not supported for fp8 flux_dtype")

     logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
```
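Note (added commentary, not part of the diff): the guard relies on every torch fp8 dtype being one byte wide, so `flux_dtype.itemsize == 1` detects a user-requested fp8 compute dtype, which would clash with `--fp8_scaled` doing its own quantization. A minimal sketch of the check:

```python
import torch

# among the dtypes this script accepts, only the fp8 formats have itemsize == 1
for dt in (torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2):
    print(dt, dt.itemsize, "rejected with --fp8_scaled" if dt.itemsize == 1 else "ok")
```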
```diff
@@ -470,13 +478,68 @@ if __name__ == "__main__":
     # if is_fp8(t5xxl_dtype):
     #     t5xxl = accelerator.prepare(t5xxl)

+    # check LoRA and OFT weights can be mergeable
+    mergeable_lora_weights = None
+    mergeable_lora_multipliers = None
+    if args.fp8_scaled and args.lora_weights:
+        assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled"
+
+        mergeable_lora_weights = []
+        mergeable_lora_multipliers = []
+
+        for weights_file in args.lora_weights:
+            if ";" in weights_file:
+                weights_file, multiplier = weights_file.split(";")
+                multiplier = float(multiplier)
+            else:
+                multiplier = 1.0
+
+            with MemoryEfficientSafeOpen(weights_file) as f:
+                keys = list(f.keys())
+
+            is_lora = is_oft = False
+            includes_text_encoder = False
+            for key in keys:
+                if key.startswith("lora"):
+                    is_lora = True
+                if key.startswith("oft"):
+                    is_oft = True
+                if key.startswith("lora_te") or key.startswith("oft_te"):
+                    includes_text_encoder = True
+                if (is_lora or is_oft) and includes_text_encoder:
+                    break
+
+            if includes_text_encoder or is_oft:
+                raise ValueError(
+                    f"LoRA weights {weights_file} that includes text encoder or OFT weights cannot be merged when using fp8_scaled"
+                )
+
+            mergeable_lora_weights.append(weights_file)
+            mergeable_lora_multipliers.append(multiplier)
+
     # DiT
+    loading_dtype = None if args.fp8_scaled else flux_dtype
+    loading_device = "cpu" if args.blocks_to_swap or args.offload else device
+
     is_schnell, model = flux_utils.load_flow_model(
-        args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
+        device,
+        args.ckpt_path,
+        loading_dtype,
+        loading_device,
+        args.model_type,
+        args.fp8_scaled,
+        lora_weights_list=mergeable_lora_weights,
+        lora_multipliers=mergeable_lora_multipliers,
     )
     model.eval()
-    logger.info(f"Casting model to {flux_dtype}")
-    model.to(flux_dtype)  # make sure model is dtype
+
+    if args.blocks_to_swap is not None and args.blocks_to_swap > 0:
+        model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False)
+        model.move_to_device_except_swap_blocks(device)
+        model.prepare_block_swap_before_forward()
+
+    # logger.info(f"Casting model to {flux_dtype}")
+    # model.to(flux_dtype)  # make sure model is dtype
     # if is_fp8(flux_dtype):
     #     model = accelerator.prepare(model)
     # if args.offload:
```
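Note: the pre-screening above exists because ordering matters — once weights are scaled and cast to fp8, a LoRA delta can no longer simply be added onto them, so LoRA files must be merged in high precision before quantization, and text-encoder/OFT weights are rejected since only the DiT state dict passes through that path. A rough sketch of the invariant, with hypothetical helper names:

```python
import torch

def merge_then_quantize(w: torch.Tensor, down: torch.Tensor, up: torch.Tensor, multiplier: float):
    # merge the LoRA delta while weights are still in high precision...
    w = w + multiplier * (up @ down)
    # ...then scale and quantize; 448.0 is the max normal value of float8_e4m3fn
    scale = w.abs().amax() / 448.0
    return (w / scale).to(torch.float8_e4m3fn), scale
```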
```diff
@@ -494,36 +557,37 @@ if __name__ == "__main__":

     # LoRA
     lora_models: List[lora_flux.LoRANetwork] = []
-    for weights_file in args.lora_weights:
-        if ";" in weights_file:
-            weights_file, multiplier = weights_file.split(";")
-            multiplier = float(multiplier)
-        else:
-            multiplier = 1.0
+    if not args.fp8_scaled:  # LoRA cannot be applied after fp8 scaling and quantization
+        for weights_file in args.lora_weights:
+            if ";" in weights_file:
+                weights_file, multiplier = weights_file.split(";")
+                multiplier = float(multiplier)
+            else:
+                multiplier = 1.0

-        weights_sd = load_file(weights_file)
-        is_lora = is_oft = False
-        for key in weights_sd.keys():
-            if key.startswith("lora"):
-                is_lora = True
-            if key.startswith("oft"):
-                is_oft = True
-            if is_lora or is_oft:
-                break
+            weights_sd = load_file(weights_file)
+            is_lora = is_oft = False
+            for key in weights_sd.keys():
+                if key.startswith("lora"):
+                    is_lora = True
+                if key.startswith("oft"):
+                    is_oft = True
+                if is_lora or is_oft:
+                    break

-        module = lora_flux if is_lora else oft_flux
-        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
+            module = lora_flux if is_lora else oft_flux
+            lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)

-        if args.merge_lora_weights:
-            lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
-        else:
-            lora_model.apply_to([clip_l, t5xxl], model)
-            info = lora_model.load_state_dict(weights_sd, strict=True)
-            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
-            lora_model.eval()
-            lora_model.to(device)
+            if args.merge_lora_weights:
+                lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
+            else:
+                lora_model.apply_to([clip_l, t5xxl], model)
+                info = lora_model.load_state_dict(weights_sd, strict=True)
+                logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
+                lora_model.eval()
+                lora_model.to(device)

-        lora_models.append(lora_model)
+            lora_models.append(lora_model)

     if not args.interactive:
         generate_image(
```
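Note: after this hunk there are two mutually exclusive LoRA paths — merged into the state dict inside `load_flow_model` when `--fp8_scaled` is set, or applied/merged at runtime otherwise. A condensed, paraphrased view of the control flow:

```python
def lora_path(fp8_scaled: bool, merge_lora_weights: bool) -> str:
    if fp8_scaled:
        return "merged at load, before quantization"   # handled by load_flow_model
    return "merged at runtime" if merge_lora_weights else "applied as separate modules"

assert lora_path(True, True) == "merged at load, before quantization"
assert lora_path(False, False) == "applied as separate modules"
```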
```diff
@@ -271,7 +271,7 @@ def train(args):

     # load FLUX
     _, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
+        accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
     )

     if args.gradient_checkpointing:
```
```diff
@@ -302,7 +302,7 @@ def train(args):
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
+        flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

     if not cache_latents:
         # load VAE here if not cached
```
```diff
@@ -265,7 +265,7 @@ def train(args):

     # load FLUX
     is_schnell, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
+        accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
     )
     flux.requires_grad_(False)
```
```diff
@@ -304,7 +304,7 @@ def train(args):
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
+        flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
         flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
         # ControlNet only has two blocks, so we can keep it on GPU
         # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
```
```diff
@@ -51,7 +51,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             self.use_clip_l = True
         else:
             self.use_clip_l = False  # Chroma does not use CLIP-L
-            assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
+            assert (
+                args.apply_t5_attn_mask
+            ), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"

         if args.fp8_base_unet:
             args.fp8_base = True  # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
```
```diff
@@ -99,18 +101,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def load_target_model(self, args, weight_dtype, accelerator):
         # currently offload to cpu for some models

-        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
-        loading_dtype = None if args.fp8_base else weight_dtype
+        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0

+        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
+        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
+        loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
+        loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
+
+        # load with quantization if needed
         _, model = flux_utils.load_flow_model(
-            args.pretrained_model_name_or_path,
-            loading_dtype,
-            "cpu",
-            disable_mmap=args.disable_mmap_load_safetensors,
-            model_type=self.model_type,
+            accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled
         )
-        if args.fp8_base:
+
+        if args.fp8_base and not args.fp8_scaled:
             # check dtype of model
             if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                 raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
```
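Note: `loading_dtype=None` means "keep the dtype stored in the checkpoint", so an fp8 checkpoint loads as fp8 and a bf16 one as bf16; the explicit `weight_dtype` is only applied when neither `fp8_base` nor `fp8_scaled` is requested. The convention, as a sketch:

```python
from typing import Optional

import torch

def resolve_loading_dtype(fp8_base: bool, fp8_scaled: bool, weight_dtype: torch.dtype) -> Optional[torch.dtype]:
    # None = leave tensors in whatever dtype the checkpoint stores
    return None if (fp8_base or fp8_scaled) else weight_dtype

print(resolve_loading_dtype(False, False, torch.bfloat16))  # torch.bfloat16
print(resolve_loading_dtype(True, False, torch.bfloat16))   # None
```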
```diff
@@ -125,12 +127,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):

         # if args.split_mode:
         #     model = self.prepare_split_model(model, weight_dtype, accelerator)

-        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
         if self.is_swapping_blocks:
             # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
             logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-            model.enable_block_swap(args.blocks_to_swap, accelerator.device)
+            model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

         if self.use_clip_l:
             clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
```
```diff
@@ -309,6 +309,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def shift_scale_latents(self, args, latents):
         return latents

+    def cast_unet(self, args):
+        return not args.fp8_scaled
+
     def get_noise_pred_and_target(
         self,
         args,
```
```diff
@@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_dit_training_arguments(parser)
     flux_train_utils.add_flux_train_arguments(parser)

+    parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
     parser.add_argument(
         "--split_mode",
         action="store_true",
```
```diff
@@ -691,25 +691,32 @@ class DoubleStreamBlock(nn.Module):
     ) -> tuple[Tensor, Tensor]:
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
+        del vec

         # prepare image for attention
-        img_modulated = self.img_norm1(img)
+        img_modulated = self.img_norm1(img.to(torch.float32)).to(img.dtype)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = self.txt_norm1(txt.to(torch.float32)).to(txt.dtype)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
         q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
         k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
         v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v

         # make attention mask if not None
         attn_mask = None
```
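Note: the recurring pattern in this hunk, `norm(x.to(torch.float32)).to(x.dtype)`, computes the normalization statistics in fp32 and casts straight back, which keeps LayerNorm numerically stable when activations are bf16/fp16 (and when surrounding weights are fp8-quantized), at the cost of a transient fp32 copy; the added `del` statements release intermediates early to trim peak VRAM. A standalone sketch of the cast pattern:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(3072, elementwise_affine=False, eps=1e-6)
x = torch.randn(1, 512, 3072, dtype=torch.bfloat16)

y = norm(x.to(torch.float32)).to(x.dtype)  # statistics in fp32, result back in bf16
assert y.dtype == torch.bfloat16
```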
```diff
@@ -725,14 +732,24 @@ class DoubleStreamBlock(nn.Module):

         attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        del q, k, v, attn

         # calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        del img_mod1, img_attn
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img.to(torch.float32)).to(img.dtype) + img_mod2.shift
+        )
+        del img_mod2

         # calculate the txt blocks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        del txt_mod1, txt_attn
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt.to(torch.float32)).to(txt.dtype) + txt_mod2.shift
+        )
+        del txt_mod2

         return img, txt

     def forward(
```
```diff
@@ -805,10 +822,14 @@ class SingleStreamBlock(nn.Module):

     def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        del vec
+        x_mod = (1 + mod.scale) * self.pre_norm(x.to(torch.float32)) + mod.shift
+        x_mod = x_mod.to(x.dtype)
+
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del qkv
         q, k = self.norm(q, k, v)

         # make attention mask if not None
```
```diff
@@ -831,9 +852,12 @@ class SingleStreamBlock(nn.Module):

         # compute attention
         attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+        del q, k, v

         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        del attn, mlp
+
         return x + mod.gate * output

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
```
```diff
@@ -969,7 +993,7 @@ class Flux(nn.Module):
         print("FLUX: Gradient checkpointing disabled.")

-    def enable_block_swap(self, num_blocks: int, device: torch.device):
+    def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
         self.blocks_to_swap = num_blocks
         double_blocks_to_swap = num_blocks // 2
         single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
```
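Note: the split assumes one double block costs roughly as much as two single blocks, so of `num_blocks` requested, `num_blocks // 2` double blocks and twice the remainder in single blocks are swapped; `supports_backward` is the new knob — the training scripts pass True so swaps are also scheduled during the backward pass, while the inference script passes False. Worked example of the split:

```python
num_blocks = 9
double_blocks_to_swap = num_blocks // 2                            # 4
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2   # 10
print(double_blocks_to_swap, single_blocks_to_swap)                # 4 10
```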
```diff
@@ -980,10 +1004,10 @@ class Flux(nn.Module):
         )

         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
```
```diff
@@ -1215,7 +1239,7 @@ class ControlNetFlux(nn.Module):
         print("FLUX: Gradient checkpointing disabled.")

-    def enable_block_swap(self, num_blocks: int, device: torch.device):
+    def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
         self.blocks_to_swap = num_blocks
         double_blocks_to_swap = num_blocks // 2
         single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
```
```diff
@@ -1226,10 +1250,10 @@ class ControlNetFlux(nn.Module):
         )

         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
```
```diff
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import replace
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import einops
 import torch
```
```diff
@@ -10,6 +10,8 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

+from library.fp8_optimization_utils import apply_fp8_monkey_patch
+from library.lora_utils import load_safetensors_with_lora_and_fp8
 from library.utils import setup_logging

 setup_logging()
```
```diff
@@ -25,6 +27,9 @@ MODEL_NAME_DEV = "dev"
 MODEL_NAME_SCHNELL = "schnell"
 MODEL_VERSION_CHROMA = "chroma"

+FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
+FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]
+

 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
```
```diff
@@ -93,17 +98,23 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int


 def load_flow_model(
+    device: Union[str, torch.device],
     ckpt_path: str,
-    dtype: Optional[torch.dtype],
-    device: Union[str, torch.device],
-    disable_mmap: bool = False,
+    dit_weight_dtype: Optional[torch.dtype],
+    loading_device: Union[str, torch.device],
     model_type: str = "flux",
+    fp8_scaled: bool = False,
+    lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
+    lora_multipliers: Optional[list[float]] = None,
 ) -> Tuple[bool, flux_models.Flux]:
+    device = torch.device(device)  # device for calculation, typically "cuda"
+    loading_device = torch.device(loading_device)
+
+    # build model
     if model_type == "flux":
         is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
         name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

-        # build model
         logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
         with torch.device("meta"):
             params = flux_models.configs[name].params
```
```diff
@@ -117,31 +128,8 @@ def load_flow_model(
             params = replace(params, depth_single_blocks=num_single_blocks)

         model = flux_models.Flux(params)
-        if dtype is not None:
-            model = model.to(dtype)
-
-        # load_sft doesn't support torch.device
-        logger.info(f"Loading state dict from {ckpt_path}")
-        sd = {}
-        for ckpt_path in ckpt_paths:
-            sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
-
-        # convert Diffusers to BFL
-        if is_diffusers:
-            logger.info("Converting Diffusers to BFL")
-            sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
-            logger.info("Converted Diffusers to BFL")
-
-        # if the key has annoying prefix, remove it
-        for key in list(sd.keys()):
-            new_key = key.replace("model.diffusion_model.", "")
-            if new_key == key:
-                break  # the model doesn't have annoying prefix
-            sd[new_key] = sd.pop(key)
-
-        info = model.load_state_dict(sd, strict=False, assign=True)
-        logger.info(f"Loaded Flux: {info}")
-        return is_schnell, model
+        if dit_weight_dtype is not None:
+            model = model.to(dit_weight_dtype)

     elif model_type == "chroma":
         from . import chroma_models
```
```diff
@@ -150,28 +138,59 @@ def load_flow_model(
         logger.info("Building Chroma model")
         with torch.device("meta"):
             model = chroma_models.Chroma(chroma_models.chroma_params)
-        if dtype is not None:
-            model = model.to(dtype)
+        if dit_weight_dtype is not None:
+            model = model.to(dit_weight_dtype)

-        # load_sft doesn't support torch.device
-        logger.info(f"Loading state dict from {ckpt_path}")
-        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+        ckpt_paths = [ckpt_path]

-        # if the key has annoying prefix, remove it
-        for key in list(sd.keys()):
-            new_key = key.replace("model.diffusion_model.", "")
-            if new_key == key:
-                break  # the model doesn't have annoying prefix
-            sd[new_key] = sd.pop(key)
+    # load state dict
+    logger.info(f"Loading DiT model from {ckpt_paths}, device={loading_device}")
+
+    sd = load_safetensors_with_lora_and_fp8(
+        model_files=ckpt_paths,
+        lora_weights_list=lora_weights_list,
+        lora_multipliers=lora_multipliers,
+        fp8_optimization=fp8_scaled,
+        calc_device=device,
+        move_to_device=(loading_device == device),
+        dit_weight_dtype=dit_weight_dtype,
+        target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
+        exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
+    )
+
+    # if the key has annoying prefix, remove it
+    for key in list(sd.keys()):
+        new_key = key.replace("model.diffusion_model.", "")
+        if new_key == key:
+            break  # the model doesn't have annoying prefix
+        sd[new_key] = sd.pop(key)
+
+    if fp8_scaled:
+        apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
+
+        if loading_device.type != "cpu":  # in case of no block swapping
+            # make sure all the model weights are on the loading_device
+            logger.info(f"Moving weights to {loading_device}")
+            for key in sd.keys():
+                sd[key] = sd[key].to(loading_device)
+
+    if model_type == "flux":
+        # convert Diffusers to BFL
+        if is_diffusers:
+            logger.info("Converting Diffusers to BFL")
+            sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
+            logger.info("Converted Diffusers to BFL")
+
+        info = model.load_state_dict(sd, strict=False, assign=True)
+        logger.info(f"Loaded Flux: {info}")
+        return is_schnell, model
+
+    elif model_type == "chroma":
+        info = model.load_state_dict(sd, strict=False, assign=True)
+        logger.info(f"Loaded Chroma: {info}")
+        is_schnell = False  # Chroma is not schnell
+        return is_schnell, model
+
+    else:
+        raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")


 def load_ae(
     ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
```
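Note: the refactored loader separates the calculation device (first argument, used for quantization math) from the loading device (where the loaded weights end up), which is what lets block-swap setups keep weights on CPU while scaling runs on GPU; oddly, the hunk annotates `lora_weights_list` as a `Dict` while the inference script above passes a list of file paths. A hedged usage sketch of the new call shape (argument values illustrative, not from the diff):

```python
from library import flux_utils

# paths and devices below are placeholders
is_schnell, model = flux_utils.load_flow_model(
    "cuda",                    # device for quantization math
    "flux1-dev.safetensors",   # ckpt_path
    None,                      # dit_weight_dtype: None keeps the checkpoint dtype
    "cpu",                     # loading_device: stay on CPU for block swap
    model_type="flux",
    fp8_scaled=True,
    lora_weights_list=None,
    lora_multipliers=None,
)
```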
```diff
@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
             state_dict[key] = value
             continue

+        original_dtype = value.dtype
+
+        if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
+            logger.warning(
+                f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
+                "Loading checkpoint as-is without re-quantization."
+            )
+            target_device = calc_device if (calc_device is not None and move_to_device) else original_device
+            value = value.to(target_device)
+            state_dict[key] = value
+            continue
+
         # Move to calculation device
         if calc_device is not None:
             value = value.to(calc_device)

-        original_dtype = value.dtype
         quantized_weight, scale_tensor = quantize_weight(
             key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
         )
```
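Note: the early exit covers all four fp8 formats torch defines; a checkpoint that is already fp8 is passed through (with only an optional device move) rather than quantized a second time, which would compound rounding error. The detection set, as a sketch:

```python
import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)

t = torch.zeros(4, dtype=torch.float8_e4m3fn)
if t.dtype in FP8_DTYPES:
    print("already fp8 - load as-is, skip re-quantization")
```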
```diff
@@ -824,27 +824,31 @@ class NetworkTrainer:
             accelerator.print("enable full bf16 training.")
             network.to(weight_dtype)

-        unet_weight_dtype = te_weight_dtype = weight_dtype
+        unet_weight_dtype = weight_dtype
+        te_weight_dtype = weight_dtype if self.cast_text_encoder(args) else None
         # Experimental Feature: Put base model into fp8 to save vram
         if args.fp8_base or args.fp8_base_unet:
             assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
             assert (
                 args.mixed_precision != "no"
             ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
-            accelerator.print("enable fp8 training for U-Net.")
-            unet_weight_dtype = torch.float8_e4m3fn
+            if self.cast_unet(args):
+                accelerator.print("enable fp8 training for U-Net.")
+                unet_weight_dtype = torch.float8_e4m3fn

-            if not args.fp8_base_unet:
-                accelerator.print("enable fp8 training for Text Encoder.")
-            te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
+            if self.cast_text_encoder(args):
+                if not args.fp8_base_unet:
+                    accelerator.print("enable fp8 training for Text Encoder.")
+                te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn

             # unet.to(accelerator.device)  # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
             # unet.to(dtype=unet_weight_dtype)  # without moving to gpu, this takes a lot of time and main memory

             # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
             # unet.to(accelerator.device, dtype=unet_weight_dtype)  # this seems to be safer than above
-            logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
-            unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator
+            if self.cast_unet(args):
+                logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
+                unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator

         unet.requires_grad_(False)
+        if self.cast_unet(args):
```
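Note: `cast_unet` appears to be a new trainer hook — the FLUX trainer returns `not args.fp8_scaled` (see the `@@ -309,6 +309,9 @@` hunk above), so the base trainer no longer casts weights that `load_flow_model` already quantized and monkey-patched; presumably the base `NetworkTrainer` defaults the hook to True. The override pattern, condensed under that assumption:

```python
class NetworkTrainer:
    def cast_unet(self, args) -> bool:
        return True  # assumed default: cast the U-Net/DiT to unet_weight_dtype

class FluxNetworkTrainer(NetworkTrainer):
    def cast_unet(self, args) -> bool:
        # fp8_scaled weights were quantized at load time; casting again would corrupt them
        return not args.fp8_scaled
```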