Merge branch 'sd3' into sdxl-ctrl-net

This commit is contained in:
Kohya S
2024-10-07 20:39:53 +09:00
7 changed files with 200 additions and 111 deletions

View File

@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
### Recent Updates ### Recent Updates
Oct 6, 2024:
- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is automatically determined to be dev or schnell. This allows schnell models to be loaded correctly. Note that LoRA training with schnell models and fine-tuning with schnell models are unverified.
- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Please specify the parent directory of `transformer` or `diffusion_pytorch_model-00001-of-00003.safetensors` with the full path. However, Diffusers format CLIP/T5XXL is not supported. Saving is supported only in BFL format.
Sep 26, 2024: Sep 26, 2024:
The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details. The implementation of block swap during FLUX.1 fine-tuning has been changed to improve speed about 10% (depends on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated. `--double_blocks_to_swap` and `--single_blocks_to_swap` are working as before, but they will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.

View File

@@ -419,9 +419,6 @@ if __name__ == "__main__":
steps = args.steps steps = args.steps
guidance_scale = args.guidance guidance_scale = args.guidance
name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
is_schnell = name == "schnell"
def is_fp8(dt): def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz] return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
@@ -455,12 +452,8 @@ if __name__ == "__main__":
# if is_fp8(t5xxl_dtype): # if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl) # t5xxl = accelerator.prepare(t5xxl)
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# DiT # DiT
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device) is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model.eval() model.eval()
logger.info(f"Casting model to {flux_dtype}") logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype model.to(flux_dtype) # make sure model is dtype
@@ -469,8 +462,12 @@ if __name__ == "__main__":
# if args.offload: # if args.offload:
# model = model.to("cpu") # model = model.to("cpu")
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# AE # AE
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device) ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
ae.eval() ae.eval()
# if is_fp8(ae_dtype): # if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae) # ae = accelerator.prepare(ae)

View File

@@ -137,6 +137,7 @@ def train(args):
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認 train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
if args.debug_dataset: if args.debug_dataset:
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -144,9 +145,8 @@ def train(args):
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
) )
) )
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
t5xxl_max_token_length = ( t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512) args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
) )
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)) strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
@@ -177,12 +177,11 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む # モデルを読み込む
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
# load VAE for caching latents # load VAE for caching latents
ae = None ae = None
if cache_latents: if cache_latents:
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype) ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False) ae.requires_grad_(False)
ae.eval() ae.eval()
@@ -196,7 +195,7 @@ def train(args):
# prepare tokenize strategy # prepare tokenize strategy
if args.t5xxl_max_token_length is None: if args.t5xxl_max_token_length is None:
if name == "schnell": if is_schnell:
t5xxl_max_token_length = 256 t5xxl_max_token_length = 256
else: else:
t5xxl_max_token_length = 512 t5xxl_max_token_length = 512
@@ -258,8 +257,8 @@ def train(args):
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
# load FLUX # load FLUX
flux = flux_utils.load_flow_model( _, flux = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
) )
if args.gradient_checkpointing: if args.gradient_checkpointing:
@@ -294,7 +293,7 @@ def train(args):
if not cache_latents: if not cache_latents:
# load VAE here if not cached # load VAE here if not cached
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu") ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False) ae.requires_grad_(False)
ae.eval() ae.eval()
ae.to(accelerator.device, dtype=weight_dtype) ae.to(accelerator.device, dtype=weight_dtype)
@@ -706,7 +705,9 @@ def train(args):
accelerator.unwrap_model(flux).prepare_block_swap_before_forward() accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
optimizer_train_fn()
if len(accelerator.trackers) > 0: if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)

View File

@@ -2,7 +2,7 @@ import argparse
import copy import copy
import math import math
import random import random
from typing import Any from typing import Any, Optional
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator
@@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sample_prompts_te_outputs = None self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
def assert_extra_args(self, args, train_dataset_group): def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group) super().assert_extra_args(args, train_dataset_group)
@@ -57,19 +58,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def get_flux_model_name(self, args):
return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models # currently offload to cpu for some models
name = self.get_flux_model_name(args)
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8) # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
model = flux_utils.load_flow_model( self.is_schnell, model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
) )
if args.fp8_base: if args.fp8_base:
# check dtype of model # check dtype of model
@@ -100,7 +97,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
elif t5xxl.dtype == torch.float8_e4m3fn: elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model") logger.info("Loaded fp8 T5XXL model")
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors) ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
@@ -142,10 +139,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return flux_lower return flux_lower
def get_tokenize_strategy(self, args): def get_tokenize_strategy(self, args):
name = self.get_flux_model_name(args) _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
if args.t5xxl_max_token_length is None: if args.t5xxl_max_token_length is None:
if name == "schnell": if is_schnell:
t5xxl_max_token_length = 256 t5xxl_max_token_length = 256
else: else:
t5xxl_max_token_length = 512 t5xxl_max_token_length = 512

View File

@@ -1,9 +1,11 @@
import json import json
from typing import Optional, Union import os
from typing import List, Optional, Tuple, Union
import einops import einops
import torch import torch
from safetensors.torch import load_file from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
@@ -17,6 +19,8 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_VERSION_FLUX_V1 = "flux1" MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
# temporary copy from sd3_utils TODO refactor # temporary copy from sd3_utils TODO refactor
@@ -39,10 +43,35 @@ def load_safetensors(
return load_file(path) # prevent device invalid Error return load_file(path) # prevent device invalid Error
def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
# check the state dict: Diffusers or BFL, dev or schnell
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = []
for ckpt_path in ckpt_paths:
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
return is_diffusers, is_schnell, ckpt_paths
def load_flow_model( def load_flow_model(
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux: ) -> Tuple[bool, flux_models.Flux]:
logger.info(f"Building Flux model {name}") is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"): with torch.device("meta"):
model = flux_models.Flux(flux_models.configs[name].params) model = flux_models.Flux(flux_models.configs[name].params)
if dtype is not None: if dtype is not None:
@@ -50,18 +79,28 @@ def load_flow_model(
# load_sft doesn't support torch.device # load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}") logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd)
logger.info("Converted Diffusers to BFL")
info = model.load_state_dict(sd, strict=False, assign=True) info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}") logger.info(f"Loaded Flux: {info}")
return model return is_schnell, model
def load_ae( def load_ae(
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder: ) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder") logger.info("Building AutoEncoder")
with torch.device("meta"): with torch.device("meta"):
ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype) # dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}") logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
@@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
""" """
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x return x
# region Diffusers
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(NUM_DOUBLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(NUM_SINGLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
return diffusers_to_bfl_map
def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map()
# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for diffusers_key, tensor in diffusers_sd.items():
if diffusers_key in diffusers_to_bfl_map:
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
return flux_sd
# endregion

View File

@@ -29,6 +29,7 @@ from safetensors.torch import safe_open
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging() setup_logging()
@@ -36,65 +37,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def convert(args): def convert(args):
# if diffusers_path is folder, get safetensors file # if diffusers_path is folder, get safetensors file
@@ -114,23 +56,7 @@ def convert(args):
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map # make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key) diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
for b in range(NUM_DOUBLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(NUM_SINGLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
# iterate over three safetensors files to reduce memory usage # iterate over three safetensors files to reduce memory usage
flux_sd = {} flux_sd = {}

View File

@@ -1042,7 +1042,9 @@ class NetworkTrainer:
text_encoder = None text_encoder = None
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn()
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
optimizer_train_fn()
if len(accelerator.trackers) > 0: if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)