mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00

add FLUX.1 LoRA training

README.md (20 lines changed)
@@ -1,5 +1,25 @@
This repository contains training, generation and utility scripts for Stable Diffusion.

## FLUX.1 LoRA training (WIP)

__Aug 9, 2024__:

Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 and CUDA 12.4. We have also updated `accelerate` to 0.33.0 just to be safe.

We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options.

```
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name
```
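
The command reads the dataset definition from `dataset_1024_bs2.toml`, which is not included in the commit. As a rough sketch only (the image directory and repeat count are placeholders), a config in the usual sd-scripts dataset format might look like the following; note that `shuffle_caption` and the caption dropout options cannot be combined with `--cache_text_encoder_outputs`:

```
[general]
caption_extension = ".txt"
shuffle_caption = false  # must stay off when caching text encoder outputs

[[datasets]]
resolution = 1024
batch_size = 2
enable_bucket = true

  [[datasets.subsets]]
  image_dir = "path/to/images"  # placeholder
  num_repeats = 1
```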

The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.

```
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors
```
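
Each `--lora` argument (`--lora_weights` in the script) may also carry a strength as `path;multiplier`, e.g. `--lora lora-flux-name.safetensors;0.8`; without a multiplier it defaults to 1.0.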

Unfortunately, the training results are not good yet. Please let us know if you have any ideas to improve the training.

## SD3 training

SD3 training is done with `sd3_train.py`.

flux_minimal_inference.py (new file, 390 lines)
@@ -0,0 +1,390 @@
# Minimum Inference Code for FLUX

import argparse
import datetime
import math
import os
import random
from typing import Callable, Optional, Tuple
import einops
import numpy as np

import torch
from safetensors.torch import safe_open, load_file
from tqdm import tqdm
from PIL import Image
import accelerate

from library import device_utils
from library.device_utils import init_ipex, get_preferred_device

init_ipex()


from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

import networks.lora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
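

# note: time_shift maps the uniform schedule through t' = exp(mu) / (exp(mu) + (1/t - 1) ** sigma),
# with mu interpolated linearly from the packed token count (0.5 at 256 tokens up to 1.15 at 4096),
# so larger images spend more of their steps at high noise levels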


def denoise(
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    vec: torch.Tensor,
    timesteps: list[float],
    guidance: float = 4.0,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec)

        img = img + (t_prev - t_curr) * pred

    return img
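

# denoise integrates the flow with plain Euler steps:
# x <- x + (t_prev - t_curr) * v, where v is the model's velocity prediction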


def do_sample(
    accelerator: Optional[accelerate.Accelerator],
    model: flux_models.Flux,
    img: torch.Tensor,
    img_ids: torch.Tensor,
    l_pooled: torch.Tensor,
    t5_out: torch.Tensor,
    txt_ids: torch.Tensor,
    num_steps: int,
    guidance: float,
    is_schnell: bool,
    device: torch.device,
    flux_dtype: torch.dtype,
):
    timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)

    # denoise initial noise
    if accelerator:
        with accelerator.autocast(), torch.no_grad():
            x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)
    else:
        with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
            x = denoise(model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance)

    return x


def generate_image(
    model,
    clip_l,
    t5xxl,
    ae,
    prompt: str,
    seed: Optional[int],
    image_width: int,
    image_height: int,
    steps: Optional[int],
    guidance: float,
):
    # make first noise with packed shape
    # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
    packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
    if seed is None:
        seed = random.randint(0, 2**32 - 1)  # manual_seed below requires an int
    noise = torch.randn(
        1,
        packed_latent_height * packed_latent_width,
        16 * 2 * 2,
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    # prepare img and img ids

    # this is needed only for img2img
    # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    # if img.shape[0] == 1 and bs > 1:
    #     img = repeat(img, "1 ... -> bs ...", bs=bs)

    # txt2img only needs img_ids
    img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

    # prepare embeddings
    logger.info("Encoding prompts...")
    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    clip_l = clip_l.to(device)
    t5xxl = t5xxl.to(device)
    with torch.no_grad():
        if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
            clip_l.to(clip_l_dtype)
            t5xxl.to(t5xxl_dtype)
            with accelerator.autocast():
                l_pooled, t5_out, txt_ids = encoding_strategy.encode_tokens(
                    tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                )
        else:
            with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
                l_pooled, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
            with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
                _, t5_out, txt_ids = encoding_strategy.encode_tokens(
                    tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                )

    # NaN check
    if torch.isnan(l_pooled).any():
        raise ValueError("NaN in l_pooled")
    if torch.isnan(t5_out).any():
        raise ValueError("NaN in t5_out")

    if args.offload:
        clip_l = clip_l.cpu()
        t5xxl = t5xxl.cpu()
    # del clip_l, t5xxl
    device_utils.clean_memory()

    # generate image
    logger.info("Generating image...")
    model = model.to(device)
    if steps is None:
        steps = 4 if is_schnell else 50

    img_ids = img_ids.to(device)
    x = do_sample(
        accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype
    )
    if args.offload:
        model = model.cpu()
    # del model
    device_utils.clean_memory()

    # unpack
    x = x.float()
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)

    # decode
    logger.info("Decoding image...")
    ae = ae.to(device)
    with torch.no_grad():
        if is_fp8(ae_dtype):
            with accelerator.autocast():
                x = ae.decode(x)
        else:
            with torch.autocast(device_type=device.type, dtype=ae_dtype):
                x = ae.decode(x)
    if args.offload:
        ae = ae.cpu()

    x = x.clamp(-1, 1)
    x = x.permute(0, 2, 3, 1)
    img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

    # save image
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    img.save(output_path)

    logger.info(f"Saved image to {output_path}")


if __name__ == "__main__":
    target_height = 768  # 1024
    target_width = 1360  # 1024

    # steps = 50  # 28  # 50
    # guidance_scale = 5
    # seed = 1  # None  # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--ae", type=str, required=False)
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
    parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
    parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
    parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
    parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
    parser.add_argument("--guidance", type=float, default=3.5)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument(
        "--lora_weights",
        type=str,
        nargs="*",
        default=[],
        help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
    )
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps
    guidance_scale = args.guidance

    name = "schnell" if "schnell" in args.ckpt_path else "dev"  # TODO change this to a more robust way
    is_schnell = name == "schnell"

    def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
        if s is None:
            return default_dtype
        if s in ["bf16", "bfloat16"]:
            return torch.bfloat16
        elif s in ["fp16", "float16"]:
            return torch.float16
        elif s in ["fp32", "float32"]:
            return torch.float32
        elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
            return torch.float8_e4m3fn
        elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
            return torch.float8_e4m3fnuz
        elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
            return torch.float8_e5m2
        elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
            return torch.float8_e5m2fnuz
        elif s in ["fp8", "float8"]:
            return torch.float8_e4m3fn  # default fp8
        else:
            raise ValueError(f"Unsupported dtype: {s}")

    def is_fp8(dt):
        return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

    dtype = str_to_dtype(args.dtype)
    clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
    t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
    ae_dtype = str_to_dtype(args.ae_dtype, dtype)
    flux_dtype = str_to_dtype(args.flux_dtype, dtype)

    logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")

    loading_device = "cpu" if args.offload else device

    use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
    if any(use_fp8):
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
    else:
        accelerator = None

    # load clip_l
    logger.info(f"Loading clip_l from {args.clip_l}...")
    clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
    clip_l.eval()

    logger.info(f"Loading t5xxl from {args.t5xxl}...")
    t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
    t5xxl.eval()

    if is_fp8(clip_l_dtype):
        clip_l = accelerator.prepare(clip_l)
    if is_fp8(t5xxl_dtype):
        t5xxl = accelerator.prepare(t5xxl)

    t5xxl_max_length = 256 if is_schnell else 512
    tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
    encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

    # DiT
    model = flux_utils.load_flow_model(name, args.ckpt_path, flux_dtype, loading_device)
    model.eval()
    logger.info(f"Casting model to {flux_dtype}")
    model.to(flux_dtype)  # make sure model is dtype
    if is_fp8(flux_dtype):
        model = accelerator.prepare(model)

    # AE
    ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
    ae.eval()
    if is_fp8(ae_dtype):
        ae = accelerator.prepare(ae)

    # LoRA
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        lora_model, weights_sd = lora_flux.create_network_from_weights(
            multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
        )
        lora_model.merge_to([clip_l, t5xxl], model, weights_sd)

    if not args.interactive:
        generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
    else:
        # loop for interactive
        width = target_width
        height = target_height
        steps = None
        guidance = args.guidance

        while True:
            print("Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance>")
            prompt = input()
            if prompt == "":
                break

            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            for opt in options[1:]:
                opt = opt.strip()
                if opt.startswith("w"):
                    width = int(opt[1:].strip())
                elif opt.startswith("h"):
                    height = int(opt[1:].strip())
                elif opt.startswith("s"):
                    steps = int(opt[1:].strip())
                elif opt.startswith("d"):
                    seed = int(opt[1:].strip())
                elif opt.startswith("g"):
                    guidance = float(opt[1:].strip())

            generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)

    logger.info("Done!")

flux_train_network.py (new file, 332 lines)
@@ -0,0 +1,332 @@
import argparse
import copy
import math
import random
from typing import Any

import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util
import train_network
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class FluxNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args, train_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        train_dataset_group.verify_bucket_reso_steps(32)

    def load_target_model(self, args, weight_dtype, accelerator):
        # currently offload to cpu for some models

        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
        clip_l.eval()

        # loading t5xxl to cpu takes a long time, so we should load to gpu in future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
        t5xxl.eval()

        name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"  # TODO change this to a more robust way
        # if we load to cpu, flux.to(fp8) takes a long time
        model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")

        return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

    def get_tokenize_strategy(self, args):
        return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]

    def get_latents_caching_strategy(self, args):
        latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
        return latents_caching_strategy

    def get_text_encoding_strategy(self, args):
        return strategy_flux.FluxTextEncodingStrategy()

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        return text_encoders  # + [accelerator.unwrap_model(text_encoders[-1])]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When the text encoders are not trained, they are not prepared by accelerate, so we need an explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
            accelerator.wait_for_everyone()

            logger.info("move text encoders back to cpu")
            text_encoders[0].to("cpu")  # , dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
            text_encoders[1].to("cpu")  # , dtype=torch.float32)
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # the outputs are fetched from the Text Encoders on every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)

    # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

    #     # get size embeddings
    #     orig_size = batch["original_sizes_hw"]
    #     crop_size = batch["crop_top_lefts"]
    #     target_size = batch["target_sizes_hw"]
    #     embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

    #     # concat embeddings
    #     encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
    #     vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
    #     text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

    #     noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
    #     return noise_pred

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        # logger.warning("Sampling images is not supported for Flux model")
        pass

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

    def encode_images_to_latents(self, args, accelerator, vae, images):
        # the Flux AutoEncoder samples internally and returns latents directly
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: flux_models.Flux,
        network,
        weight_dtype,
        train_unet,
    ):
        # copy from sd3_train.py and modified

        def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
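            # look up each sampled timestep's sigma in the copied scheduler and
            # broadcast it to the latent tensor's rank so it can scale a batch of latents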
            sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
            schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device)
            timesteps = timesteps.to(accelerator.device)
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

            sigma = sigmas[step_indices].flatten()
            while len(sigma.shape) < n_dim:
                sigma = sigma.unsqueeze(-1)
            return sigma

        def compute_density_for_timestep_sampling(
            weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
        ):
            """Compute the density for sampling the timesteps when doing SD3 training.

            Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

            SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
            """
            if weighting_scheme == "logit_normal":
                # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
                u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
                u = torch.nn.functional.sigmoid(u)
            elif weighting_scheme == "mode":
                u = torch.rand(size=(batch_size,), device="cpu")
                u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
            else:
                u = torch.rand(size=(batch_size,), device="cpu")
            return u

        def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
            """Computes loss weighting scheme for SD3 training.

            Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

            SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
            """
            if weighting_scheme == "sigma_sqrt":
                weighting = (sigmas**-2.0).float()
            elif weighting_scheme == "cosmap":
                bot = 1 - 2 * sigmas + 2 * sigmas**2
                weighting = 2 / (math.pi * bot)
            else:
                weighting = torch.ones_like(sigmas)
            return weighting

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # Sample a random timestep for each image
        # for weighting schemes where we sample timesteps non-uniformly
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)

        # Add noise according to flow matching.
        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
        noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
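        # i.e. x_t = (1 - sigma) * x_0 + sigma * noise: a straight path between data
        # and noise, as in rectified flow / flow matching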

        # pack latents and get img_ids
        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

        # get guidance
        guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)

        # ensure the hidden state will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                t.requires_grad_(True)
            img_ids.requires_grad_(True)
            guidance_vec.requires_grad_(True)

        # Predict the noise residual
        l_pooled, t5_out, txt_ids = text_encoder_conds
        # print(
        #     f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
        # )

        with accelerator.autocast():
            # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
            model_pred = unet(
                img=packed_noisy_model_input,
                img_ids=img_ids,
                txt=t5_out,
                txt_ids=txt_ids,
                y=l_pooled,
                timesteps=timesteps / 1000,
                guidance=guidance_vec,
            )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

        model_pred = model_pred * (-sigmas) + noisy_model_input
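        # if the prediction is a velocity v ~ noise - x_0, then x_t - sigma * v is an
        # estimate of the original latents (preconditioning of the model output)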

        # these weighting schemes use a uniform timestep sampling
        # and instead post-weight the loss
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

        # flow matching loss: this is different from SD3
        target = noise - latents

        return model_pred, target, timesteps, None, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    # sdxl_train_util.add_sdxl_training_arguments(parser)
    parser.add_argument("--clip_l", type=str, help="path to clip_l")
    parser.add_argument("--t5xxl", type=str, help="path to t5xxl")
    parser.add_argument("--ae", type=str, help="path to ae")
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )

    # copy from Diffusers
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="none",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="the FLUX.1 dev variant is a guidance distilled model",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = FluxNetworkTrainer()
    trainer.train(args)

library/flux_models.py (new file, 920 lines)
@@ -0,0 +1,920 @@
# copy from FLUX repo: https://github.com/black-forest-labs/flux
# license: Apache-2.0 License


from dataclasses import dataclass
import math

import torch
from einops import rearrange
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint

# USE_REENTRANT = True


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


# region autoencoder


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


# endregion
# region config


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    # repo_id: str | None
    # repo_flow: str | None
    # repo_ae: str | None


configs = {
    "dev": ModelSpec(
        # repo_id="black-forest-labs/FLUX.1-dev",
        # repo_flow="flux1-dev.sft",
        # repo_ae="ae.sft",
        ckpt_path=None,  # os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=None,  # os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "schnell": ModelSpec(
        # repo_id="black-forest-labs/FLUX.1-schnell",
        # repo_flow="flux1-schnell.sft",
        # repo_ae="ae.sft",
        ckpt_path=None,  # os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=None,  # os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


# endregion

# region math


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()
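

# rope builds, for each position and frequency, the 2x2 rotation matrix
# [[cos, -sin], [sin, cos]] (computed in float64 for precision); apply_rope below
# applies it to (even, odd) pairs of the query/key channels and casts back to the
# input dtype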


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


# endregion


# region layers
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)
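

# EmbedND concatenates one rotary table per position-id axis (axes_dim, e.g.
# [16, 56, 56] channels for the three id components) along the frequency axis
# and adds a singleton head dimension for broadcasting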


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    # def forward(self, x):
    #     if self.training and self.gradient_checkpointing:
    #         def create_custom_forward(func):
    #             def custom_forward(*inputs):
    #                 return func(*inputs)
    #             return custom_forward
    #         return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
    #     else:
    #         return self._forward(x)


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        # return (x * rrms).to(dtype=x_dtype) * self.scale
        return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
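

# RMSNorm computes the normalization and the learned scale in float32 and casts
# back to the input dtype only at the end, which keeps bf16/fp8 runs stable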


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

        # self.gradient_checkpointing = False

    # def enable_gradient_checkpointing(self):
    #     self.gradient_checkpointing = True

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = self.proj(x)
        return x

    # def forward(self, *args, **kwargs):
    #     if self.training and self.gradient_checkpointing:
    #         return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
    #     else:
    #         return self._forward(*args, **kwargs)


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
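

# Modulation projects the conditioning vector through SiLU + Linear into 3 chunks
# (shift, scale, gate) for single-stream use, or 6 chunks (two such triples, one
# for the attention branch and one for the MLP branch) when double=True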
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        # self.img_attn.enable_gradient_checkpointing()
        # self.txt_attn.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        # self.img_attn.disable_gradient_checkpointing()
        # self.txt_attn.disable_gradient_checkpointing()

    def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    # def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
    #     if self.training and self.gradient_checkpointing:
    #         def create_custom_forward(func):
    #             def custom_forward(*inputs):
    #                 return func(*inputs)
    #             return custom_forward
    #         return torch.utils.checkpoint.checkpoint(
    #             create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
    #         )
    #     else:
    #         return self._forward(img, txt, vec, pe)

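# Illustrative note on the forward/_forward split above (a sketch, not
# original text): torch.utils.checkpoint.checkpoint re-runs _forward during
# the backward pass instead of storing intermediate activations, trading extra
# compute for memory. With use_reentrant=False the non-reentrant
# implementation is used, which supports keyword arguments:
#
#   img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
#   # checkpointed only when block.training and block.gradient_checkpointing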
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)

    # def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
    #     if self.training and self.gradient_checkpointing:
    #         def create_custom_forward(func):
    #             def custom_forward(*inputs):
    #                 return func(*inputs)
    #             return custom_forward
    #         return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
    #     else:
    #         return self._forward(x, vec, pe)

class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x


# endregion

class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        self.gradient_checkpointing = False

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

        self.time_in.enable_gradient_checkpointing()
        self.vector_in.enable_gradient_checkpointing()
        self.guidance_in.enable_gradient_checkpointing()

        for block in self.double_blocks + self.single_blocks:
            block.enable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing enabled.")

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

        self.time_in.disable_gradient_checkpointing()
        self.vector_in.disable_gradient_checkpointing()
        self.guidance_in.disable_gradient_checkpointing()

        for block in self.double_blocks + self.single_blocks:
            block.disable_gradient_checkpointing()

        print("FLUX: Gradient checkpointing disabled.")

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
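# Usage sketch (illustrative; names and shapes are assumptions, not original
# text). The model consumes packed 2x2 latent patches plus positional ids:
#
#   # latents: [B, 16, H/8, W/8] from the AE -> packed: [B, (H/16)*(W/16), 64]
#   img = pack_latents(latents)                      # see library/flux_utils.py
#   img_ids = prepare_img_ids(B, H // 16, W // 16)
#   txt_ids = torch.zeros(B, t5_out.shape[1], 3)
#   pred = model(img, img_ids, t5_out, txt_ids, timesteps, y=l_pooled, guidance=g)
#   latents_pred = unpack_latents(pred, H // 16, W // 16)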
215
library/flux_utils.py
Normal file
@@ -0,0 +1,215 @@
import json
from typing import Union
import einops
import torch

from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config

from library import flux_models

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

MODEL_VERSION_FLUX_V1 = "flux1"

def load_flow_model(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.Flux:
    logger.info(f"Building Flux model {name}")
    with torch.device("meta"):
        model = flux_models.Flux(flux_models.configs[name].params).to(dtype)

    # load_sft doesn't support torch.device
    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_file(ckpt_path, device=str(device))
    info = model.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Flux: {info}")
    return model

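# Usage sketch (illustrative; the "dev" config key and path are assumptions):
#
#   model = load_flow_model("dev", "flux1-dev.sft", torch.bfloat16, "cpu")
#
# Building on the meta device and loading with assign=True avoids a second
# full-size allocation: parameters are materialized directly from the
# checkpoint tensors instead of being copied into pre-allocated storage.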
def load_ae(name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> flux_models.AutoEncoder:
    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_file(ckpt_path, device=str(device))
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae

def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> CLIPTextModel:
    logger.info("Building CLIP")
    # NOTE: this dict contains duplicate keys (e.g. hidden_act, hidden_size); Python
    # dict literals keep the last occurrence, so the values nearest the bottom win.
    CLIPL_CONFIG = {
        "_name_or_path": "clip-vit-large-patch14/",
        "architectures": ["CLIPModel"],
        "initializer_factor": 1.0,
        "logit_scale_init_value": 2.6592,
        "model_type": "clip",
        "projection_dim": 768,
        # "text_config": {
        "_name_or_path": "",
        "add_cross_attention": False,
        "architectures": None,
        "attention_dropout": 0.0,
        "bad_words_ids": None,
        "bos_token_id": 0,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "dropout": 0.0,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": 2,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "quick_gelu",
        "hidden_size": 768,
        "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
        "initializer_factor": 1.0,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {"LABEL_0": 0, "LABEL_1": 1},
        "layer_norm_eps": 1e-05,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 77,
        "min_length": 0,
        "model_type": "clip_text_model",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": 1,
        "prefix": None,
        "problem_type": None,
        "projection_dim": 768,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "sep_token_id": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": True,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": None,
        "torchscript": False,
        "transformers_version": "4.16.0.dev0",
        "use_bfloat16": False,
        "vocab_size": 49408,
        "hidden_act": "gelu",
        "hidden_size": 1280,
        "intermediate_size": 5120,
        "num_attention_heads": 20,
        "num_hidden_layers": 32,
        # },
        # "text_config_dict": {
        "hidden_size": 768,
        "intermediate_size": 3072,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "projection_dim": 768,
        # },
        # "torch_dtype": "float32",
        # "transformers_version": None,
    }
    config = CLIPConfig(**CLIPL_CONFIG)
    with init_empty_weights():
        clip = CLIPTextModel._from_config(config)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_file(ckpt_path, device=str(device))
    info = clip.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP: {info}")
    return clip

def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device]) -> T5EncoderModel:
    T5_CONFIG_JSON = """
{
  "architectures": [
    "T5EncoderModel"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 10240,
  "d_kv": 64,
  "d_model": 4096,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 24,
  "num_heads": 64,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.41.2",
  "use_cache": true,
  "vocab_size": 32128
}
"""
    config = json.loads(T5_CONFIG_JSON)
    config = T5Config(**config)
    with init_empty_weights():
        t5xxl = T5EncoderModel._from_config(config)

    logger.info(f"Loading state dict from {ckpt_path}")
    sd = load_file(ckpt_path, device=str(device))
    info = t5xxl.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded T5xxl: {info}")
    return t5xxl

def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
    img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
    img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
    return img_ids


def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """
    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
    """
    x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
    return x


def pack_latents(x: torch.Tensor) -> torch.Tensor:
    """
    x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
    """
    x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    return x
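# Round-trip sketch (illustrative, not part of the module): a 1024x1024 image
# gives 128x128 AE latents with 16 channels, packed into 64x64 patch tokens of
# dim 16 * 2 * 2 = 64:
#
#   lat = torch.randn(1, 16, 128, 128)
#   packed = pack_latents(lat)                       # [1, 4096, 64]
#   assert torch.equal(unpack_latents(packed, 64, 64), lat)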
@@ -15,6 +15,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 from transformers import CLIPTokenizer, T5TokenizerFast
+from .utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 memory_efficient_attention = None
@@ -95,7 +101,9 @@ class SDTokenizer:
             batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
 
         # truncate to max_length
-        print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
+        print(
+            f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}"
+        )
         if truncate_to_max_length and len(batch) > self.max_length:
             batch = batch[: self.max_length]
         if truncate_length is not None and len(batch) > truncate_length:
@@ -1554,6 +1562,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.set_clip_options({"layer": layer_idx})
         self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    def gradient_checkpointing_enable(self):
+        logger.warning("Gradient checkpointing is not supported for this model")
+
     def set_attn_mode(self, mode):
         raise NotImplementedError("This model does not support setting the attention mode")
@@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s
         return_projected_pooled=False,
         textmodel_json_config=CLIPL_CONFIG,
     )
+    clip_l.gradient_checkpointing_enable()
     if state_dict is not None:
         # update state_dict if provided to include logit_scale and text_projection.weight avoid errors
         if "logit_scale" not in state_dict:
244
library/strategy_flux.py
Normal file
@@ -0,0 +1,244 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast

from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"

class FluxTokenizeStrategy(TokenizeStrategy):
    def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.t5xxl_max_length = t5xxl_max_length
        self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
        self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        text = [text] if isinstance(text, str) else text

        l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
        t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")

        t5_attn_mask = t5_tokens["attention_mask"]
        l_tokens = l_tokens["input_ids"]
        t5_tokens = t5_tokens["input_ids"]

        return [l_tokens, t5_tokens, t5_attn_mask]

class FluxTextEncodingStrategy(TextEncodingStrategy):
    def __init__(self) -> None:
        pass

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_t5_attn_mask: bool = False,
    ) -> List[torch.Tensor]:
        # supports single model inference only

        clip_l, t5xxl = models
        l_tokens, t5_tokens = tokens[:2]
        t5_attn_mask = tokens[2] if len(tokens) > 2 else None

        if clip_l is not None and l_tokens is not None:
            l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
        else:
            l_pooled = None

        if t5xxl is not None and t5_tokens is not None:
            # t5_out is [1, max length, 4096]
            t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
            if apply_t5_attn_mask:
                t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
            txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device)
        else:
            t5_out = None
            txt_ids = None

        return [l_pooled, t5_out, txt_ids]

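# Usage sketch (illustrative; `clip_l` and `t5xxl` are assumed to be the
# encoders returned by the library/flux_utils.py loaders):
#
#   tokenize_strategy = FluxTokenizeStrategy(256)
#   encoding_strategy = FluxTextEncodingStrategy()
#   tokens = tokenize_strategy.tokenize("a cat holding a sign")
#   with torch.no_grad():
#       l_pooled, t5_out, txt_ids = encoding_strategy.encode_tokens(
#           tokenize_strategy, [clip_l, t5xxl], tokens
#       )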
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
        apply_t5_attn_mask: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
        self.apply_t5_attn_mask = apply_t5_attn_mask

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            npz = np.load(npz_path)
            if "l_pooled" not in npz:
                return False
            if "t5_out" not in npz:
                return False
            if "txt_ids" not in npz:
                return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
        return t5_out * np.expand_dims(t5_attn_mask, -1)

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        data = np.load(npz_path)
        l_pooled = data["l_pooled"]
        t5_out = data["t5_out"]
        txt_ids = data["txt_ids"]

        if self.apply_t5_attn_mask:
            t5_attn_mask = data["t5_attn_mask"]
            t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)

        return [l_pooled, t5_out, txt_ids]

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
    ):
        flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            l_pooled, t5_out, txt_ids = flux_text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens_and_masks, self.apply_t5_attn_mask
            )

        if l_pooled.dtype == torch.bfloat16:
            l_pooled = l_pooled.float()
        if t5_out.dtype == torch.bfloat16:
            t5_out = t5_out.float()
        if txt_ids.dtype == torch.bfloat16:
            txt_ids = txt_ids.float()

        l_pooled = l_pooled.cpu().numpy()
        t5_out = t5_out.cpu().numpy()
        txt_ids = txt_ids.cpu().numpy()

        for i, info in enumerate(infos):
            l_pooled_i = l_pooled[i]
            t5_out_i = t5_out[i]
            txt_ids_i = txt_ids[i]

            if self.cache_to_disk:
                t5_attn_mask = tokens_and_masks[2]
                t5_attn_mask_i = t5_attn_mask[i].cpu().numpy()
                np.savez(
                    info.text_encoder_outputs_npz,
                    l_pooled=l_pooled_i,
                    t5_out=t5_out_i,
                    txt_ids=txt_ids_i,
                    t5_attn_mask=t5_attn_mask_i,
                )
            else:
                info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i)

class FluxLatentsCachingStrategy(LatentsCachingStrategy):
    FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX)
        if len(npz_file) == 0:
            return None, None
        w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
        return int(w), int(h)

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        return (
            os.path.splitext(absolute_path)[0]
            + f"_{image_size[0]:04d}x{image_size[1]:04d}"
            + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
        )

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
        vae_device = vae.device
        vae_dtype = vae.dtype

        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(vae.device)

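# Naming sketch (derived from the code above, shown for illustration): latents
# for /data/img/cat.png at bucket resolution 1024x768 would be cached as
# /data/img/cat_1024x0768_flux.npz, and the width/height are parsed back out
# of that "_{w:04d}x{h:04d}" segment by get_image_size_from_disk_cache_path.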
if __name__ == "__main__":
    # test code for FluxTokenizeStrategy
    # tokenizer = sd3_models.SD3Tokenizer()
    strategy = FluxTokenizeStrategy(256)
    text = "hello world"

    l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
    # print(l_tokens.shape)
    print(l_tokens)
    print(t5_tokens)
    print(t5_attn_mask)

    texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
    l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
    t5_tokens_2 = strategy.t5xxl(
        texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    print(l_tokens_2)
    print(t5_tokens_2)

    # compare
    print(torch.allclose(l_tokens[0], l_tokens_2["input_ids"][0]))
    print(torch.allclose(t5_tokens[0], t5_tokens_2["input_ids"][0]))

    text = ",".join(["hello world! this is long text"] * 50)
    l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
    print(l_tokens)
    print(t5_tokens)

    print(f"model max length l: {strategy.clip_l.model_max_length}")
    print(f"model max length t5: {strategy.t5xxl.model_max_length}")
730
networks/lora_flux.py
Normal file
@@ -0,0 +1,730 @@
# temporary minimum implementation of LoRA
# FLUX doesn't have Conv2d, so we ignore it
# TODO commonize with the original implementation

# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel

setup_logging()
import logging

logger = logging.getLogger(__name__)

class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            # the scale could be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded + lx * self.multiplier * scale

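# Illustrative note (a sketch of the hook above, not original text): apply_to
# swaps the wrapped module's bound forward, so for a Linear layer the patched
# call computes
#
#   y = W x + lora_up(lora_down(x)) * multiplier * (alpha / lora_dim)
#
# which starts out as an exact no-op because lora_up is zero-initialized.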
class LoRAInfModule(LoRAModule):
    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
        self.enabled = True
        self.network: LoRANetwork = None

    def set_network(self, network):
        self.network = network

    # freeze and merge the weights
    def merge_to(self, sd, dtype, device):
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # calc in float

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get up/down weight
        up_weight = sd["lora_up.weight"].to(torch.float).to(device)
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)

        # merge weight
        if len(weight.size()) == 2:
            # linear
            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                weight
                + self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            # logger.info(conved.size(), weight.size(), module.stride, module.padding)
            weight = weight + self.multiplier * conved * self.scale

        # set weight to org_module
        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    # return this module's weight so that the merge can be restored later
    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        # get up/down weight from module
        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        # pre-calculated weight
        if len(down_weight.size()) == 2:
            # linear
            weight = self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = self.multiplier * conved * self.scale

        return weight

    def set_region(self, region):
        self.region = region
        self.region_mask = None

    def default_forward(self, x):
        # logger.info(f"default_forward {self.lora_name} {x.size()}")
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

    def forward(self, x):
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)

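# Merge sketch (an illustrative restatement of merge_to for the Linear case;
# all FLUX targets here are Linear, the conv branches are carried over from
# the original LoRA implementation):
#
#   W' = W + multiplier * (lora_up.weight @ lora_down.weight) * (alpha / rank)
#
# get_weight returns only the delta term, so a merge could be undone by
# subtracting the same quantity.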
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    ae: AutoencoderKL,
    text_encoders: List[CLIPTextModel],
    flux,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    # extract dim/alpha for conv2d, and block dim
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # that's a lot of arguments ( ^ω^)...
    network = LoRANetwork(
        text_encoders,
        flux,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
        varbose=True,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network

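# Usage sketch (an assumption about how the trainer wires this up, based on
# the --network_module / --network_args convention; not shown in this diff):
# flux_train_network.py imports networks.lora_flux and calls create_network
# with --network_dim / --network_alpha, while extra key=value pairs such as
# --network_args "rank_dropout=0.1" "module_dropout=0.1" would arrive here as
# the kwargs read above.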
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
    # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # logger.info(lora_name, value.size(), dim)

    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(text_encoders, flux, multiplier=multiplier, module_class=module_class)
    return network, weights_sd

class LoRANetwork(torch.nn.Module):
    FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_FLUX = "lora_flux"
    LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"

    def __init__(
        self,
        text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        logger.info(
            f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
        )
        if self.conv_lora_dim is not None:
            logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")

        # create module instances
        def create_modules(
            is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str]
        ) -> Tuple[List[LoRAModule], List[str]]:
            prefix = (
                self.LORA_PREFIX_FLUX
                if is_flux
                else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5)
            )

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            # usually, target all the modules
                            if is_linear or is_conv2d_1x1:
                                dim = self.lora_dim
                                alpha = self.alpha
                            elif self.conv_lora_dim is not None:
                                dim = self.conv_lora_dim
                                alpha = self.conv_alpha

                            if dim is None or dim == 0:
                                # record the skipped modules for logging
                                if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                            )
                            loras.append(lora)
            return loras, skipped

        # create LoRA for text encoder
        # TODO: creating all the modules every time is wasteful; this should be revisited
        self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            index = i
            logger.info(f"create LoRA for Text Encoder {index+1}:")

            text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
            skipped_te += skipped
        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE)
        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
        if varbose and len(skipped) > 0:
            logger.warning(
                f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def set_enabled(self, is_enabled):
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    # returns whether the weights can be merged
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None):
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX):
                apply_unet = True

        if apply_text_encoder:
            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info("weights are merged")

def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||||
|
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||||
|
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
||||||
|
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
||||||
|
|
||||||
|
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||||
|
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||||
|
|
||||||
|
    # it might be good to allow separate learning rates for the two text encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        # TODO warn if optimizer is not compatible with LoRA+ (but it will cause an error anyway, so maybe no need to check here?)
        # if (
        #     self.loraplus_lr_ratio is not None
        #     or self.loraplus_text_encoder_lr_ratio is not None
        #     or self.loraplus_unet_lr_ratio is not None
        # ):
        #     assert (
        #         optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
        #     ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"

        self.requires_grad_(True)

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, ratio):
            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
                for name, param in lora.named_parameters():
                    if ratio is not None and "lora_up" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}

                if len(param_data["params"]) == 0:
                    continue

                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * ratio
                    else:
                        param_data["lr"] = lr

                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue

                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")

            return params, descriptions

        if self.text_encoder_loras:
            params, descriptions = assemble_params(
                self.text_encoder_loras,
                text_encoder_lr if text_encoder_lr is not None else default_lr,
                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            # if self.block_lr:
            #     is_sdxl = False
            #     for lora in self.unet_loras:
            #         if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
            #             is_sdxl = True
            #             break
            #
            #     # to graph the learning rate per block, classify the LoRA modules by block
            #     block_idx_to_lora = {}
            #     for lora in self.unet_loras:
            #         idx = get_block_index(lora.lora_name, is_sdxl)
            #         if idx not in block_idx_to_lora:
            #             block_idx_to_lora[idx] = []
            #         block_idx_to_lora[idx].append(lora)
            #
            #     # set the parameters for each block
            #     for idx, block_loras in block_idx_to_lora.items():
            #         params, descriptions = assemble_params(
            #             block_loras,
            #             (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
            #             self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            #         )
            #         all_params.extend(params)
            #         lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
            #
            # else:
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions
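    # A minimal usage sketch for the two methods above (hypothetical values; the trainer normally
    # drives this):
    #   network.set_loraplus_lr_ratio(16, None, None)  # LoRA+: "lora_up" params get 16x the base LR
    #   params, descs = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
    #   optimizer = torch.optim.AdamW(params)  # each group carries its own "lr" key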
    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()
    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
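    # Editorial note: the "sshs_model_hash" / "sshs_legacy_hash" metadata entries let downstream
    # tools identify the checkpoint without rehashing the whole file, which is what the indexing
    # comment above refers to.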
    def backup_weights(self):
        # back up the original weights of the wrapped modules
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the original weights of the wrapped modules
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True
    def pre_calculation(self):
        # pre-calculate: merge the LoRA weights into the original modules
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False
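    # Editorial note: after pre_calculation the wrapped modules hold W + ΔW directly and every
    # LoRA module is disabled, removing the per-forward LoRA overhead; restore_weights undoes
    # this using the copies taken by backup_weights.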
    def apply_max_norm_regularization(self, max_norm_value, device):
        downkeys = []
        upkeys = []
        alphakeys = []
        norms = []
        keys_scaled = 0

        state_dict = self.state_dict()
        for key in state_dict.keys():
            if "lora_down" in key and "weight" in key:
                downkeys.append(key)
                upkeys.append(key.replace("lora_down", "lora_up"))
                alphakeys.append(key.replace("lora_down.weight", "alpha"))

        for i in range(len(downkeys)):
            down = state_dict[downkeys[i]].to(device)
            up = state_dict[upkeys[i]].to(device)
            alpha = state_dict[alphakeys[i]].to(device)
            dim = down.shape[0]
            scale = alpha / dim

            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
            else:
                updown = up @ down

            updown *= scale

            norm = updown.norm().clamp(min=max_norm_value / 2)
            desired = torch.clamp(norm, max=max_norm_value)
            ratio = desired.cpu() / norm.cpu()
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                state_dict[upkeys[i]] *= sqrt_ratio
                state_dict[downkeys[i]] *= sqrt_ratio
            scalednorm = updown.norm() * ratio
            norms.append(scalednorm.item())

        return keys_scaled, sum(norms) / len(norms), max(norms)
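    # Editorial note: scaling both lora_up and lora_down by ratio**0.5 scales their product (the
    # effective update) by exactly ratio, so the norm is clamped to max_norm_value while the two
    # factors stay balanced. The return value is (keys_scaled, mean_norm, max_norm).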
@@ -52,6 +52,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
         self.logit_scale = logit_scale
         self.ckpt_info = ckpt_info
 
+        # incorporate xformers or memory efficient attention into the model
+        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+        if torch.__version__ >= "2.0.0":  # this can be used with xformers built for PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)
+
         return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
 
     def get_tokenize_strategy(self, args):
157
train_network.py
@@ -100,6 +100,12 @@ class NetworkTrainer:
     def load_target_model(self, args, weight_dtype, accelerator):
         text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
 
+        # incorporate xformers or memory efficient attention into the model
+        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+        if torch.__version__ >= "2.0.0":  # this can be used with xformers built for PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)
+
         return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
 
     def get_tokenize_strategy(self, args):
@@ -147,6 +153,81 @@ class NetworkTrainer:
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoder, unet)
 
+    # region SD/SDXL
+
+    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
+        noise_scheduler = DDPMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
+        )
+        prepare_scheduler_for_custom_training(noise_scheduler, device)
+        if args.zero_terminal_snr:
+            custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+        return noise_scheduler
+
+    def encode_images_to_latents(self, args, accelerator, vae, images):
+        return vae.encode(images).latent_dist.sample()
+
+    def shift_scale_latents(self, args, latents):
+        return latents * self.vae_scale_factor
+
+    def get_noise_pred_and_target(
+        self,
+        args,
+        accelerator,
+        noise_scheduler,
+        latents,
+        batch,
+        text_encoder_conds,
+        unet,
+        network,
+        weight_dtype,
+        train_unet,
+    ):
+        # Sample noise, sample a random timestep for each image, and add noise to the latents,
+        # with noise offset and/or multires noise if specified
+        noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+
+        # ensure the hidden state will require grad
+        if args.gradient_checkpointing:
+            for x in noisy_latents:
+                x.requires_grad_(True)
+            for t in text_encoder_conds:
+                t.requires_grad_(True)
+
+        # Predict the noise residual
+        with accelerator.autocast():
+            noise_pred = self.call_unet(
+                args,
+                accelerator,
+                unet,
+                noisy_latents.requires_grad_(train_unet),
+                timesteps,
+                text_encoder_conds,
+                batch,
+                weight_dtype,
+            )
+
+        if args.v_parameterization:
+            # v-parameterization training
+            target = noise_scheduler.get_velocity(latents, noise, timesteps)
+        else:
+            target = noise
+
+        return noise_pred, target, timesteps, huber_c, None
+
+    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
+        if args.min_snr_gamma:
+            loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+        if args.scale_v_pred_loss_like_noise_pred:
+            loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+        if args.v_pred_like_loss:
+            loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+        if args.debiased_estimation_loss:
+            loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+        return loss
+
+    # endregion
+
     def train(self, args):
         session_id = random.randint(0, 2**32)
         training_started_at = time.time()
@@ -253,11 +334,6 @@ class NetworkTrainer:
         # text_encoder is List[CLIPTextModel] or CLIPTextModel
         text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
 
-        # incorporate xformers or memory efficient attention into the model
-        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        if torch.__version__ >= "2.0.0":  # this can be used with xformers built for PyTorch 2.0.0 or later
-            vae.set_use_memory_efficient_attention_xformers(args.xformers)
-
         # load the model for differential (additional-difference) training
         sys.path.append(os.path.dirname(__file__))
         accelerator.print("import network module:", args.network_module)
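(Editorial note: the block removed in this hunk is the same xformers / memory-efficient-attention setup added to `load_target_model` above, moving it out of the shared `train` method so architecture-specific subclasses can handle it themselves.)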
@@ -445,14 +521,17 @@ class NetworkTrainer:
             unet_weight_dtype = torch.float8_e4m3fn
             te_weight_dtype = torch.float8_e4m3fn
 
+            unet.to(accelerator.device)  # this makes the `to(dtype)` below faster
+
         unet.requires_grad_(False)
-        unet.to(dtype=unet_weight_dtype)
+        unet.to(dtype=unet_weight_dtype)  # this takes a long time and large memory
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
 
             # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
             if t_enc.device.type != "cpu":
                 t_enc.to(dtype=te_weight_dtype)
+                if hasattr(t_enc.text_model, "embeddings"):
                     # nn.Embedding does not support FP8
                     t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
@@ -851,12 +930,7 @@ class NetworkTrainer:
 
         global_step = 0
 
-        noise_scheduler = DDPMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
-        )
-        prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
-        if args.zero_terminal_snr:
-            custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+        noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
 
         if accelerator.is_main_process:
             init_kwargs = {}
@@ -913,6 +987,13 @@ class NetworkTrainer:
             initial_step -= len(train_dataloader)
         global_step = initial_step
 
+        # log device and dtype for each model
+        logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
+        for t_enc in text_encoders:
+            logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
+
+        clean_memory_on_device(accelerator.device)
+
         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
@@ -940,13 +1021,15 @@ class NetworkTrainer:
                 else:
                     with torch.no_grad():
                         # convert images to latents
-                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
+                        latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
+                        latents = latents.to(dtype=weight_dtype)
 
                         # if NaN is found, show a warning and replace it with zeros
                         if torch.any(torch.isnan(latents)):
                             accelerator.print("NaN found in latents, replacing with zeros")
                             latents = torch.nan_to_num(latents, 0, out=latents)
-                latents = latents * self.vae_scale_factor
+
+                latents = self.shift_scale_latents(args, latents)
 
                 # get multiplier for each sample
                 if network_has_multiplier:
@@ -985,41 +1068,25 @@ class NetworkTrainer:
                 if args.full_fp16:
                     text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
 
-                # Sample noise, sample a random timestep for each image, and add noise to the latents,
-                # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
-                    args, noise_scheduler, latents
-                )
-
-                # ensure the hidden state will require grad
-                if args.gradient_checkpointing:
-                    for x in noisy_latents:
-                        x.requires_grad_(True)
-                    for t in text_encoder_conds:
-                        t.requires_grad_(True)
-
-                # Predict the noise residual
-                with accelerator.autocast():
-                    noise_pred = self.call_unet(
-                        args,
-                        accelerator,
-                        unet,
-                        noisy_latents.requires_grad_(train_unet),
-                        timesteps,
-                        text_encoder_conds,
-                        batch,
-                        weight_dtype,
-                    )
-
-                if args.v_parameterization:
-                    # v-parameterization training
-                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
-                else:
-                    target = noise
+                # sample noise, call unet, get target
+                noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+                    args,
+                    accelerator,
+                    noise_scheduler,
+                    latents,
+                    batch,
+                    text_encoder_conds,
+                    unet,
+                    network,
+                    weight_dtype,
+                    train_unet,
+                )
 
                 loss = train_util.conditional_loss(
                     noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
                 )
+                if weighting is not None:
+                    loss = loss * weighting
                 if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
@@ -1027,14 +1094,8 @@ class NetworkTrainer:
                     loss_weights = batch["loss_weights"]  # weight for each sample
                     loss = loss * loss_weights
 
-                if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
-                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
-                if args.v_pred_like_loss:
-                    loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
-                if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
+                loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
 
                 loss = loss.mean()  # already averaged over the batch, so no need to divide by batch_size
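Taken together, these hunks turn the SD/SDXL-specific steps of `NetworkTrainer.train` into overridable hooks. As an editorial sketch (not part of this commit), an architecture-specific trainer would subclass them roughly as follows; only the method names and signatures come from the diff above, and the bodies are placeholders:

```
import argparse
from typing import Any

import torch

import train_network


class MyNetworkTrainer(train_network.NetworkTrainer):
    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        # return an architecture-specific scheduler instead of the default DDPMScheduler
        raise NotImplementedError

    def encode_images_to_latents(self, args, accelerator, vae, images):
        # e.g. an autoencoder whose encode() API differs from the SD VAE
        raise NotImplementedError

    def shift_scale_latents(self, args, latents):
        # the SD/SDXL default multiplies by self.vae_scale_factor
        raise NotImplementedError

    def get_noise_pred_and_target(
        self, args, accelerator, noise_scheduler, latents, batch, text_encoder_conds, unet, network, weight_dtype, train_unet
    ):
        # must return (noise_pred, target, timesteps, huber_c, weighting); weighting may be None
        raise NotImplementedError
```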