Compare commits

...

5 Commits

Author SHA1 Message Date
Kohya S
5585347a52 Merge branch 'sd3' into feat-flux-chroma-fp8-scaled 2026-01-18 16:58:01 +09:00
Kohya S
b9357c662d Merge branch 'sd3' into feat-flux-chroma-fp8-scaled 2026-01-18 14:37:12 +09:00
Kohya S.
08dc0858b5 Merge pull request #2224 from rockerBOO/block-swap-fp8-scaled
Block swap and fp8 scaled, fp8 already quantized warning
2026-01-18 14:35:32 +09:00
rockerBOO
f9710863ca Set is_swapping_blocks before loading_device, add warning for ignoring fp8_scaled if already fp8 2025-10-10 15:58:21 -04:00
Kohya S
f6137a7175 feat: add fp8 optimization for FLUX 2025-09-25 22:10:11 +09:00
8 changed files with 237 additions and 111 deletions

View File

@@ -18,6 +18,7 @@ from safetensors.torch import load_file
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from networks import oft_flux
init_ipex()
@@ -325,6 +326,7 @@ def generate_image(
# generate image
logger.info("Generating image...")
if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0):
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50
@@ -411,12 +413,16 @@ if __name__ == "__main__":
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage"
)
parser.add_argument(
"--lora_weights",
type=str,
@@ -442,6 +448,8 @@ if __name__ == "__main__":
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
if args.fp8_scaled and flux_dtype.itemsize == 1:
raise ValueError("fp8_scaled is not supported for fp8 flux_dtype")
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
@@ -470,13 +478,68 @@ if __name__ == "__main__":
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
# check LoRA and OFT weights can be mergeable
mergeable_lora_weights = None
mergeable_lora_multipliers = None
if args.fp8_scaled and args.lora_weights:
assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled"
mergeable_lora_weights = []
mergeable_lora_multipliers = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
with MemoryEfficientSafeOpen(weights_file) as f:
keys = list(f.keys())
is_lora = is_oft = False
includes_text_encoder = False
for key in keys:
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if key.startswith("lora_te") or key.startswith("oft_te"):
includes_text_encoder = True
if (is_lora or is_oft) and includes_text_encoder:
break
if includes_text_encoder or is_oft:
raise ValueError(
f"LoRA weights {weights_file} that includes text encoder or OFT weights cannot be merged when using fp8_scaled"
)
mergeable_lora_weights.append(weights_file)
mergeable_lora_multipliers.append(multiplier)
# DiT
loading_dtype = None if args.fp8_scaled else flux_dtype
loading_device = "cpu" if args.blocks_to_swap or args.offload else device
is_schnell, model = flux_utils.load_flow_model( is_schnell, model = flux_utils.load_flow_model(
args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type device,
args.ckpt_path,
loading_dtype,
loading_device,
args.model_type,
args.fp8_scaled,
lora_weights_list=mergeable_lora_weights,
lora_multipliers=mergeable_lora_multipliers,
)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype if args.blocks_to_swap is not None and args.blocks_to_swap > 0:
model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False)
model.move_to_device_except_swap_blocks(device)
model.prepare_block_swap_before_forward()
# logger.info(f"Casting model to {flux_dtype}")
# model.to(flux_dtype) # make sure model is dtype
# if is_fp8(flux_dtype): # if is_fp8(flux_dtype):
# model = accelerator.prepare(model) # model = accelerator.prepare(model)
# if args.offload: # if args.offload:
@@ -494,6 +557,7 @@ if __name__ == "__main__":
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
if not args.fp8_scaled: # LoRA cannot be applied after fp8 scaling and quantization
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")

View File

@@ -271,7 +271,7 @@ def train(args):
# load FLUX # load FLUX
_, flux = flux_utils.load_flow_model( _, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
) )
if args.gradient_checkpointing: if args.gradient_checkpointing:
@@ -302,7 +302,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device) flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
if not cache_latents:
# load VAE here if not cached

View File

@@ -265,7 +265,7 @@ def train(args):
# load FLUX # load FLUX
is_schnell, flux = flux_utils.load_flow_model( is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux" accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
) )
flux.requires_grad_(False) flux.requires_grad_(False)
@@ -304,7 +304,7 @@ def train(args):
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you! # This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device) flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
# ControlNet only has two blocks, so we can keep it on GPU # ControlNet only has two blocks, so we can keep it on GPU
# controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device) # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)

View File

@@ -51,7 +51,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.use_clip_l = True self.use_clip_l = True
else: else:
self.use_clip_l = False # Chroma does not use CLIP-L self.use_clip_l = False # Chroma does not use CLIP-L
assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります" assert (
args.apply_t5_attn_mask
), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
if args.fp8_base_unet: if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1 args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -99,18 +101,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models # currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8) self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
# load with quantization if needed
_, model = flux_utils.load_flow_model( _, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
model_type=self.model_type,
) )
if args.fp8_base:
if args.fp8_base and not args.fp8_scaled:
# check dtype of model # check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz: if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}") raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
@@ -125,12 +127,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if args.split_mode: # if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator) # model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks: if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device) model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
if self.use_clip_l: if self.use_clip_l:
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -309,6 +309,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def shift_scale_latents(self, args, latents): def shift_scale_latents(self, args, latents):
return latents return latents
def cast_unet(self, args):
return not args.fp8_scaled
def get_noise_pred_and_target( def get_noise_pred_and_target(
self, self,
args, args,
@@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_dit_training_arguments(parser) train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser) flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument( parser.add_argument(
"--split_mode", "--split_mode",
action="store_true", action="store_true",

View File

@@ -691,25 +691,32 @@ class DoubleStreamBlock(nn.Module):
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
del vec
# prepare image for attention # prepare image for attention
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img.to(torch.float32)).to(img.dtype)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated) img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention # prepare txt for attention
txt_modulated = self.txt_norm1(txt) txt_modulated = self.txt_norm1(txt.to(torch.float32)).to(txt.dtype)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated) txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention # run actual attention
q = torch.cat((txt_q, img_q), dim=2) q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2) k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2) v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# make attention mask if not None # make attention mask if not None
attn_mask = None attn_mask = None
@@ -725,14 +732,24 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
del q, k, v, attn
# calculate the img blocks # calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) del img_mod1, img_attn
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img.to(torch.float32)).to(img.dtype) + img_mod2.shift
)
del img_mod2
# calculate the txt blocks # calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) del txt_mod1, txt_attn
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt.to(torch.float32)).to(txt.dtype) + txt_mod2.shift
)
del txt_mod2
return img, txt return img, txt
def forward( def forward(
@@ -805,10 +822,14 @@ class SingleStreamBlock(nn.Module):
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift del vec
x_mod = (1 + mod.scale) * self.pre_norm(x.to(torch.float32)) + mod.shift
x_mod = x_mod.to(x.dtype)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del qkv
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
# make attention mask if not None # make attention mask if not None
@@ -831,9 +852,12 @@ class SingleStreamBlock(nn.Module):
# compute attention # compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
del attn, mlp
return x + mod.gate * output return x + mod.gate * output
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
@@ -969,7 +993,7 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.") print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2 double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
@@ -980,10 +1004,10 @@ class Flux(nn.Module):
) )
self.offloader_double = custom_offloading_utils.ModelOffloader( self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device # , debug=True self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
) )
self.offloader_single = custom_offloading_utils.ModelOffloader( self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device # , debug=True self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
) )
print( print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1215,7 +1239,7 @@ class ControlNetFlux(nn.Module):
print("FLUX: Gradient checkpointing disabled.") print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2 double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
@@ -1226,10 +1250,10 @@ class ControlNetFlux(nn.Module):
) )
self.offloader_double = custom_offloading_utils.ModelOffloader( self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, double_blocks_to_swap, device # , debug=True self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
) )
self.offloader_single = custom_offloading_utils.ModelOffloader( self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, single_blocks_to_swap, device # , debug=True self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward # , debug=True
) )
print( print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."

View File

@@ -1,7 +1,7 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import einops
import torch
@@ -10,6 +10,8 @@ from safetensors import safe_open
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
@@ -25,6 +27,9 @@ MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell" MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma" MODEL_VERSION_CHROMA = "chroma"
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
""" """
@@ -93,17 +98,23 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
def load_flow_model( def load_flow_model(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device], device: Union[str, torch.device],
disable_mmap: bool = False, ckpt_path: str,
dit_weight_dtype: Optional[torch.dtype],
loading_device: Union[str, torch.device],
model_type: str = "flux", model_type: str = "flux",
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> Tuple[bool, flux_models.Flux]: ) -> Tuple[bool, flux_models.Flux]:
device = torch.device(device) # device for calculation, typically "cuda"
loading_device = torch.device(loading_device)
# build model
if model_type == "flux": if model_type == "flux":
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint") logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"): with torch.device("meta"):
params = flux_models.configs[name].params params = flux_models.configs[name].params
@@ -117,31 +128,8 @@ def load_flow_model(
params = replace(params, depth_single_blocks=num_single_blocks) params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params) model = flux_models.Flux(params)
if dtype is not None: if dit_weight_dtype is not None:
model = model.to(dtype) model = model.to(dit_weight_dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
elif model_type == "chroma": elif model_type == "chroma":
from . import chroma_models from . import chroma_models
@@ -150,12 +138,25 @@ def load_flow_model(
logger.info("Building Chroma model") logger.info("Building Chroma model")
with torch.device("meta"): with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params) model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None: if dit_weight_dtype is not None:
model = model.to(dtype) model = model.to(dit_weight_dtype)
# load_sft doesn't support torch.device ckpt_paths = [ckpt_path]
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) # load state dict
logger.info(f"Loading DiT model from {ckpt_paths}, device={loading_device}")
sd = load_safetensors_with_lora_and_fp8(
model_files=ckpt_paths,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
)
# if the key has annoying prefix, remove it # if the key has annoying prefix, remove it
for key in list(sd.keys()): for key in list(sd.keys()):
@@ -164,14 +165,32 @@ def load_flow_model(
break # the model doesn't have annoying prefix break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key) sd[new_key] = sd.pop(key)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu": # in case of no block swapping
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
if model_type == "flux":
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
elif model_type == "chroma":
info = model.load_state_dict(sd, strict=False, assign=True) info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}") logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell is_schnell = False # Chroma is not schnell
return is_schnell, model return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
def load_ae( def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False

View File

@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
state_dict[key] = value state_dict[key] = value
continue continue
original_dtype = value.dtype
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
logger.warning(
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
"Loading checkpoint as-is without re-quantization."
)
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device # Move to calculation device
if calc_device is not None: if calc_device is not None:
value = value.to(calc_device) value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight( quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
) )

View File

@@ -824,16 +824,19 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.") accelerator.print("enable full bf16 training.")
network.to(weight_dtype) network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype unet_weight_dtype = weight_dtype
te_weight_dtype = weight_dtype if self.cast_text_encoder(args) else None
# Experimental Feature: Put base model into fp8 to save vram # Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base or args.fp8_base_unet: if args.fp8_base or args.fp8_base_unet:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert ( assert (
args.mixed_precision != "no" args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
if self.cast_unet(args):
accelerator.print("enable fp8 training for U-Net.") accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn unet_weight_dtype = torch.float8_e4m3fn
if self.cast_text_encoder(args):
if not args.fp8_base_unet: if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.") accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
@@ -843,6 +846,7 @@ class NetworkTrainer:
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
if self.cast_unet(args):
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator