Kohya S.
2025-10-16 02:56:09 +08:00
committed by GitHub
7 changed files with 222 additions and 107 deletions

View File

@@ -18,6 +18,7 @@ from safetensors.torch import load_file
 from library import device_utils
 from library.device_utils import init_ipex, get_preferred_device
+from library.safetensors_utils import MemoryEfficientSafeOpen
 from networks import oft_flux

 init_ipex()
@@ -325,7 +326,8 @@ def generate_image(
     # generate image
     logger.info("Generating image...")
-    model = model.to(device)
+    if args.offload and not (args.blocks_to_swap is not None and args.blocks_to_swap > 0):
+        model = model.to(device)

     if steps is None:
         steps = 4 if is_schnell else 50
@@ -411,12 +413,16 @@ if __name__ == "__main__":
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae") parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl") parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux") parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for flux model")
parser.add_argument("--seed", type=int, default=None) parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev") parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5) parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None) parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--blocks_to_swap", type=int, default=None, help="Number of blocks to swap between CPU and GPU to reduce memory usage"
)
parser.add_argument( parser.add_argument(
"--lora_weights", "--lora_weights",
type=str, type=str,
@@ -442,6 +448,8 @@ if __name__ == "__main__":
     t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
     ae_dtype = str_to_dtype(args.ae_dtype, dtype)
     flux_dtype = str_to_dtype(args.flux_dtype, dtype)
+    if args.fp8_scaled and flux_dtype.itemsize == 1:
+        raise ValueError("fp8_scaled is not supported for fp8 flux_dtype")
     logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
@@ -470,13 +478,68 @@ if __name__ == "__main__":
     # if is_fp8(t5xxl_dtype):
     #     t5xxl = accelerator.prepare(t5xxl)

+    # check LoRA and OFT weights can be mergeable
+    mergeable_lora_weights = None
+    mergeable_lora_multipliers = None
+    if args.fp8_scaled and args.lora_weights:
+        assert args.merge_lora_weights, "LoRA weights must be merged when using fp8_scaled"
+        mergeable_lora_weights = []
+        mergeable_lora_multipliers = []
+        for weights_file in args.lora_weights:
+            if ";" in weights_file:
+                weights_file, multiplier = weights_file.split(";")
+                multiplier = float(multiplier)
+            else:
+                multiplier = 1.0
+            with MemoryEfficientSafeOpen(weights_file) as f:
+                keys = list(f.keys())
+                is_lora = is_oft = False
+                includes_text_encoder = False
+                for key in keys:
+                    if key.startswith("lora"):
+                        is_lora = True
+                    if key.startswith("oft"):
+                        is_oft = True
+                    if key.startswith("lora_te") or key.startswith("oft_te"):
+                        includes_text_encoder = True
+                    if (is_lora or is_oft) and includes_text_encoder:
+                        break
+            if includes_text_encoder or is_oft:
+                raise ValueError(
+                    f"LoRA weights {weights_file} that includes text encoder or OFT weights cannot be merged when using fp8_scaled"
+                )
+            mergeable_lora_weights.append(weights_file)
+            mergeable_lora_multipliers.append(multiplier)

     # DiT
+    loading_dtype = None if args.fp8_scaled else flux_dtype
+    loading_device = "cpu" if args.blocks_to_swap or args.offload else device
     is_schnell, model = flux_utils.load_flow_model(
-        args.ckpt_path, None, loading_device, disable_mmap=True, model_type=args.model_type
+        device,
+        args.ckpt_path,
+        loading_dtype,
+        loading_device,
+        args.model_type,
+        args.fp8_scaled,
+        lora_weights_list=mergeable_lora_weights,
+        lora_multipliers=mergeable_lora_multipliers,
     )
     model.eval()
-    logger.info(f"Casting model to {flux_dtype}")
-    model.to(flux_dtype)  # make sure model is dtype
+    if args.blocks_to_swap is not None and args.blocks_to_swap > 0:
+        model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=False)
+        model.move_to_device_except_swap_blocks(device)
+        model.prepare_block_swap_before_forward()
+    # logger.info(f"Casting model to {flux_dtype}")
+    # model.to(flux_dtype)  # make sure model is dtype
     # if is_fp8(flux_dtype):
     #     model = accelerator.prepare(model)
     # if args.offload:
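
Note: the mergeability pre-check above only reads key names from the safetensors file: keys starting with "lora" mark a LoRA, "oft" an OFT, and "lora_te" / "oft_te" mean text encoder modules are included; OFT and text encoder weights cannot be merged into an fp8-scaled DiT. A standalone restatement of the classification, using safetensors.safe_open in place of the repository's MemoryEfficientSafeOpen:

from safetensors import safe_open

def classify_network_file(path: str) -> tuple[bool, bool, bool]:
    """Return (is_lora, is_oft, includes_text_encoder) based on key prefixes only."""
    is_lora = is_oft = includes_te = False
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            is_lora |= key.startswith("lora")
            is_oft |= key.startswith("oft")
            includes_te |= key.startswith("lora_te") or key.startswith("oft_te")
    return is_lora, is_oft, includes_te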
@@ -494,36 +557,37 @@ if __name__ == "__main__":
     # LoRA
     lora_models: List[lora_flux.LoRANetwork] = []
-    for weights_file in args.lora_weights:
-        if ";" in weights_file:
-            weights_file, multiplier = weights_file.split(";")
-            multiplier = float(multiplier)
-        else:
-            multiplier = 1.0
-
-        weights_sd = load_file(weights_file)
-        is_lora = is_oft = False
-        for key in weights_sd.keys():
-            if key.startswith("lora"):
-                is_lora = True
-            if key.startswith("oft"):
-                is_oft = True
-            if is_lora or is_oft:
-                break
-        module = lora_flux if is_lora else oft_flux
-        lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
-
-        if args.merge_lora_weights:
-            lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
-        else:
-            lora_model.apply_to([clip_l, t5xxl], model)
-            info = lora_model.load_state_dict(weights_sd, strict=True)
-            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
-            lora_model.eval()
-            lora_model.to(device)
-
-        lora_models.append(lora_model)
+    if not args.fp8_scaled:  # LoRA cannot be applied after fp8 scaling and quantization
+        for weights_file in args.lora_weights:
+            if ";" in weights_file:
+                weights_file, multiplier = weights_file.split(";")
+                multiplier = float(multiplier)
+            else:
+                multiplier = 1.0
+
+            weights_sd = load_file(weights_file)
+            is_lora = is_oft = False
+            for key in weights_sd.keys():
+                if key.startswith("lora"):
+                    is_lora = True
+                if key.startswith("oft"):
+                    is_oft = True
+                if is_lora or is_oft:
+                    break
+            module = lora_flux if is_lora else oft_flux
+            lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
+
+            if args.merge_lora_weights:
+                lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
+            else:
+                lora_model.apply_to([clip_l, t5xxl], model)
+                info = lora_model.load_state_dict(weights_sd, strict=True)
+                logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
+                lora_model.eval()
+                lora_model.to(device)
+
+            lora_models.append(lora_model)

     if not args.interactive:
         generate_image(

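Note: --lora_weights accepts an optional multiplier after a semicolon (e.g. "my_lora.safetensors;0.8", a placeholder path); both loops in this file default the multiplier to 1.0 when none is given. A minimal restatement of that parsing:

def parse_lora_weights_arg(spec: str) -> tuple[str, float]:
    # "path.safetensors;0.8" -> ("path.safetensors", 0.8); a bare path gets multiplier 1.0
    if ";" in spec:
        path, multiplier = spec.split(";")
        return path, float(multiplier)
    return spec, 1.0
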
View File

@@ -271,7 +271,7 @@ def train(args):
     # load FLUX
     _, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
+        accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
     )

     if args.gradient_checkpointing:
@@ -302,7 +302,7 @@ def train(args):
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
+        flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

     if not cache_latents:
         # load VAE here if not cached

View File

@@ -265,7 +265,7 @@ def train(args):
     # load FLUX
     is_schnell, flux = flux_utils.load_flow_model(
-        args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
+        accelerator.device, args.pretrained_model_name_or_path, weight_dtype, "cpu", model_type="flux"
     )

     flux.requires_grad_(False)
@@ -304,7 +304,7 @@ def train(args):
         # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
         # This idea is based on 2kpr's great work. Thank you!
         logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-        flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
+        flux.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)
         flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
         # ControlNet only has two blocks, so we can keep it on GPU
         # controlnet.enable_block_swap(args.blocks_to_swap, accelerator.device)
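
Note: the training scripts now pass supports_backward=True because swapped blocks must also be resident during the backward pass, while the minimal inference script above passes False. A sketch of how a caller chooses, using a hypothetical helper around the methods added in this commit:

def setup_block_swap(model, blocks_to_swap, device, training: bool) -> None:
    # Hypothetical convenience wrapper: forward-only swapping for inference,
    # forward+backward swapping for training.
    if blocks_to_swap is not None and blocks_to_swap > 0:
        model.enable_block_swap(blocks_to_swap, device, supports_backward=training)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()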

View File

@@ -51,7 +51,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             self.use_clip_l = True
         else:
             self.use_clip_l = False  # Chroma does not use CLIP-L
-            assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"
+            assert (
+                args.apply_t5_attn_mask
+            ), "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"

         if args.fp8_base_unet:
             args.fp8_base = True  # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
@@ -100,17 +102,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         # currently offload to cpu for some models
         # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
-        loading_dtype = None if args.fp8_base else weight_dtype
+        loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
+        loading_device = "cpu" if self.is_swapping_blocks else accelerator.device

-        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
+        # load with quantization if needed
         _, model = flux_utils.load_flow_model(
-            args.pretrained_model_name_or_path,
-            loading_dtype,
-            "cpu",
-            disable_mmap=args.disable_mmap_load_safetensors,
-            model_type=self.model_type,
+            accelerator.device, args.pretrained_model_name_or_path, loading_dtype, loading_device, self.model_type, args.fp8_scaled
         )

-        if args.fp8_base:
+        if args.fp8_base and not args.fp8_scaled:
             # check dtype of model
             if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                 raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
@@ -130,7 +130,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if self.is_swapping_blocks:
             # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
             logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
-            model.enable_block_swap(args.blocks_to_swap, accelerator.device)
+            model.enable_block_swap(args.blocks_to_swap, accelerator.device, supports_backward=True)

         if self.use_clip_l:
             clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
@@ -309,6 +309,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def shift_scale_latents(self, args, latents):
         return latents

+    def cast_unet(self, args):
+        return not args.fp8_scaled
+
     def get_noise_pred_and_target(
         self,
         args,
@@ -525,6 +528,7 @@ def setup_parser() -> argparse.ArgumentParser:
     train_util.add_dit_training_arguments(parser)
     flux_train_utils.add_flux_train_arguments(parser)
+    parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
     parser.add_argument(
         "--split_mode",
         action="store_true",

View File

@@ -691,25 +691,32 @@ class DoubleStreamBlock(nn.Module):
     ) -> tuple[Tensor, Tensor]:
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
+        del vec

         # prepare image for attention
-        img_modulated = self.img_norm1(img)
+        img_modulated = self.img_norm1(img.to(torch.float32)).to(img.dtype)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
-        txt_modulated = self.txt_norm1(txt)
+        txt_modulated = self.txt_norm1(txt.to(torch.float32)).to(txt.dtype)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
         q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
         k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
         v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v

         # make attention mask if not None
         attn_mask = None
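
Note: besides the del statements that release intermediates early, the norms are now evaluated in float32 and cast back to the activation dtype, which keeps the normalization statistics stable when the surrounding activations are bf16 (or the linears are fp8-patched). A minimal sketch of the pattern:

import torch
from torch import nn, Tensor

def norm_in_fp32(norm: nn.Module, x: Tensor) -> Tensor:
    # Compute normalization in float32, then return to the input dtype,
    # mirroring self.img_norm1(img.to(torch.float32)).to(img.dtype) above.
    return norm(x.to(torch.float32)).to(x.dtype)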
@@ -725,14 +732,24 @@ class DoubleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        del q, k, v, attn

         # calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        del img_mod1, img_attn
+        img = img + img_mod2.gate * self.img_mlp(
+            (1 + img_mod2.scale) * self.img_norm2(img.to(torch.float32)).to(img.dtype) + img_mod2.shift
+        )
+        del img_mod2

         # calculate the txt blocks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        del txt_mod1, txt_attn
+        txt = txt + txt_mod2.gate * self.txt_mlp(
+            (1 + txt_mod2.scale) * self.txt_norm2(txt.to(torch.float32)).to(txt.dtype) + txt_mod2.shift
+        )
+        del txt_mod2

         return img, txt

     def forward(
@@ -805,10 +822,14 @@ class SingleStreamBlock(nn.Module):
     def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        del vec
+        x_mod = (1 + mod.scale) * self.pre_norm(x.to(torch.float32)) + mod.shift
+        x_mod = x_mod.to(x.dtype)
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del qkv
         q, k = self.norm(q, k, v)

         # make attention mask if not None
@@ -831,9 +852,12 @@ class SingleStreamBlock(nn.Module):
         # compute attention
         attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+        del q, k, v

         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        del attn, mlp
         return x + mod.gate * output

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
@@ -969,7 +993,7 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.") print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2 double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
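
Note: the swap budget is split so that half of num_blocks is taken from the double blocks and the remainder is doubled for the single blocks (one double block is treated as two singles). A restatement with example values:

def split_swap_budget(num_blocks: int) -> tuple[int, int]:
    double_blocks_to_swap = num_blocks // 2
    single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
    return double_blocks_to_swap, single_blocks_to_swap

# split_swap_budget(6) == (3, 6); split_swap_budget(9) == (4, 10)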
@@ -980,10 +1004,10 @@ class Flux(nn.Module):
         )
         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1215,7 +1239,7 @@ class ControlNetFlux(nn.Module):
print("FLUX: Gradient checkpointing disabled.") print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool = False):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
double_blocks_to_swap = num_blocks // 2 double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
@@ -1226,10 +1250,10 @@ class ControlNetFlux(nn.Module):
         )
         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks, double_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks, single_blocks_to_swap, device, supports_backward=supports_backward  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."

View File

@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import replace
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import einops
 import torch
@@ -10,6 +10,8 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

+from library.fp8_optimization_utils import apply_fp8_monkey_patch
+from library.lora_utils import load_safetensors_with_lora_and_fp8
 from library.utils import setup_logging

 setup_logging()
@@ -25,6 +27,9 @@ MODEL_NAME_DEV = "dev"
 MODEL_NAME_SCHNELL = "schnell"
 MODEL_VERSION_CHROMA = "chroma"

+FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
+FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]
+

 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
@@ -93,17 +98,23 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
 def load_flow_model(
-    ckpt_path: str,
-    dtype: Optional[torch.dtype],
     device: Union[str, torch.device],
-    disable_mmap: bool = False,
+    ckpt_path: str,
+    dit_weight_dtype: Optional[torch.dtype],
+    loading_device: Union[str, torch.device],
     model_type: str = "flux",
+    fp8_scaled: bool = False,
+    lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
+    lora_multipliers: Optional[list[float]] = None,
 ) -> Tuple[bool, flux_models.Flux]:
+    device = torch.device(device)  # device for calculation, typically "cuda"
+    loading_device = torch.device(loading_device)
+
+    # build model
     if model_type == "flux":
         is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
         name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

-        # build model
         logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
         with torch.device("meta"):
             params = flux_models.configs[name].params
@@ -117,31 +128,8 @@ def load_flow_model(
                 params = replace(params, depth_single_blocks=num_single_blocks)

             model = flux_models.Flux(params)
-            if dtype is not None:
-                model = model.to(dtype)
-
-        # load_sft doesn't support torch.device
-        logger.info(f"Loading state dict from {ckpt_path}")
-        sd = {}
-        for ckpt_path in ckpt_paths:
-            sd.update(load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype))
-
-        # convert Diffusers to BFL
-        if is_diffusers:
-            logger.info("Converting Diffusers to BFL")
-            sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
-            logger.info("Converted Diffusers to BFL")
-
-        # if the key has annoying prefix, remove it
-        for key in list(sd.keys()):
-            new_key = key.replace("model.diffusion_model.", "")
-            if new_key == key:
-                break  # the model doesn't have annoying prefix
-            sd[new_key] = sd.pop(key)
-
-        info = model.load_state_dict(sd, strict=False, assign=True)
-        logger.info(f"Loaded Flux: {info}")
-        return is_schnell, model
+            if dit_weight_dtype is not None:
+                model = model.to(dit_weight_dtype)

     elif model_type == "chroma":
         from . import chroma_models
@@ -150,28 +138,59 @@ def load_flow_model(
logger.info("Building Chroma model") logger.info("Building Chroma model")
with torch.device("meta"): with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params) model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None: if dit_weight_dtype is not None:
model = model.to(dtype) model = model.to(dit_weight_dtype)
# load_sft doesn't support torch.device ckpt_paths = [ckpt_path]
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
# if the key has annoying prefix, remove it # load state dict
for key in list(sd.keys()): logger.info(f"Loading DiT model from {ckpt_paths}, device={loading_device}")
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
sd = load_safetensors_with_lora_and_fp8(
model_files=ckpt_paths,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
)
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu": # in case of no block swapping
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
if model_type == "flux":
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
elif model_type == "chroma":
info = model.load_state_dict(sd, strict=False, assign=True) info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}") logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell is_schnell = False # Chroma is not schnell
return is_schnell, model return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")
def load_ae( def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False

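Note: load_flow_model now takes the compute device first and a separate loading device, and can optionally quantize to scaled fp8 and merge LoRA weights while loading. A sketch of a call under the new signature (paths and dtypes are placeholders):

import torch
from library import flux_utils

is_schnell, model = flux_utils.load_flow_model(
    torch.device("cuda"),              # device used for computation (and fp8 scaling math)
    "/path/to/flux1-dev.safetensors",  # placeholder checkpoint path
    torch.bfloat16,                    # dit_weight_dtype; callers pass None when fp8_scaled quantizes the weights
    "cpu",                             # loading_device, e.g. "cpu" when block swap is enabled
    model_type="flux",
    fp8_scaled=False,
)
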
View File

@@ -824,27 +824,31 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.") accelerator.print("enable full bf16 training.")
network.to(weight_dtype) network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype unet_weight_dtype = weight_dtype
te_weight_dtype = weight_dtype if self.cast_text_encoder(args) else None
# Experimental Feature: Put base model into fp8 to save vram # Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base or args.fp8_base_unet: if args.fp8_base or args.fp8_base_unet:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert ( assert (
args.mixed_precision != "no" args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training for U-Net.") if self.cast_unet(args):
unet_weight_dtype = torch.float8_e4m3fn accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
if not args.fp8_base_unet: if self.cast_text_encoder(args):
accelerator.print("enable fp8 training for Text Encoder.") if not args.fp8_base_unet:
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}") # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above # unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}") if self.cast_unet(args):
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
unet.requires_grad_(False) unet.requires_grad_(False)
if self.cast_unet(args): if self.cast_unet(args):
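
Note: the cast_unet / cast_text_encoder hooks let a trainer subclass opt out of the blanket float8 cast; the FLUX trainer above returns not args.fp8_scaled so a model that was already quantized to scaled fp8 at load time is left untouched. A minimal sketch of the pattern, with class names and defaults assumed for illustration:

class NetworkTrainerBase:
    def cast_unet(self, args) -> bool:
        return True  # default assumed here: cast the diffusion model as before

    def cast_text_encoder(self, args) -> bool:
        return True  # default assumed here: cast the text encoders as before

class FluxTrainerSketch(NetworkTrainerBase):
    def cast_unet(self, args) -> bool:
        # weights were already quantized to scaled fp8 while loading, so skip the cast
        return not args.fp8_scaled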