Merge remote-tracking branch 'hina/feature/val-loss' into validation-loss-upstream

Modified implementation for process_batch and cleanup validation
recording
This commit is contained in:
rockerBOO
2025-01-03 00:48:08 -05:00
85 changed files with 23666 additions and 1552 deletions

138
library/adafactor_fused.py Normal file
View File

@@ -0,0 +1,138 @@
import math
import torch
from transformers import Adafactor
# stochastic rounding for bfloat16
# The implementation was provided by 2kpr. Thank you very much!
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
"""
copies source into target using stochastic rounding
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
"""
# create a random 16 bit integer
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)
@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)
return loss
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)

View File

@@ -10,13 +10,7 @@ import json
from pathlib import Path
# from toolz import curry
from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
from typing import Dict, List, Optional, Sequence, Tuple, Union
import toml
import voluptuous
@@ -78,6 +72,9 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
@dataclass
@@ -86,11 +83,13 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False
alpha_mask: bool = False
@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
alpha_mask: bool = False
@dataclass
@@ -102,14 +101,13 @@ class ControlNetSubsetParams(BaseSubsetParams):
@dataclass
class BaseDatasetParams:
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -191,11 +189,13 @@ class ConfigSanitizer:
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"caption_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -212,11 +212,13 @@ class ConfigSanitizer:
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
"alpha_mask": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
"alpha_mask": bool,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
@@ -480,7 +482,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
datasets.append(dataset)
val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
@@ -488,17 +490,17 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)
# print info
@@ -543,6 +545,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")
if is_dreambooth:
@@ -564,6 +568,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
print("Validation dataset")
print_info(val_datasets)
if len(val_datasets) > 0:
info = ""
for i, dataset in enumerate(val_datasets):
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Validation Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
"""
),
" ",
)
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
@@ -574,7 +622,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset.set_seed(seed)
for i, dataset in enumerate(val_datasets):
print(f"[Validation Dataset {i}]")
logger.info(f"[Validation Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

View File

@@ -0,0 +1,227 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# print(module_to_cpu.__class__, module_to_cuda.__class__)
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
else:
if module_to_cuda.weight.data.device.type != device.type:
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
"""
not tested
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
class Offloader:
"""
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
self.debug = debug
self.thread_pool = ThreadPoolExecutor(max_workers=1)
self.futures = {}
self.cuda_available = device.type == "cuda"
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
)
def _wait_blocks_move(self, block_idx):
if block_idx not in self.futures:
return
if self.debug:
print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
_, bidx_to_cuda = future.result()
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def __del__(self):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")
if swapping:
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
if waiting:
self._wait_blocks_move(block_idx_to_wait)
return None
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

View File

@@ -98,10 +98,13 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
return loss
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1 / torch.sqrt(snr_t)
if v_prediction:
weight = 1 / (snr_t + 1)
else:
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
@@ -482,12 +485,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss

1493
library/flux_models.py Normal file

File diff suppressed because it is too large Load Diff

619
library/flux_train_utils.py Normal file
View File

@@ -0,0 +1,619 @@
import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, PartialState
from transformers import CLIPTextModel
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sample images
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
flux,
ae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad():
x = ae.decode(x)
ae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
img_ids=img_ids,
controlnet_cond=controlnet_img,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
# endregion
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass
elif args.model_prediction_type == "additive":
# add the model_pred to the noisy_model_input
model_pred = model_pred + noisy_model_input
elif args.model_prediction_type == "sigma_scaled":
# apply sigma scaling
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, weighting
def save_models(
ckpt_path: str,
flux: flux_models.Flux,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("", flux.state_dict())
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_flux_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_flux_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
flux: flux_models.Flux,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# endregion
def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument(
"--t5xxl",
type=str,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=None,
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)

488
library/flux_utils.py Normal file
View File

@@ -0,0 +1,488 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
Args:
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
Returns:
Tuple[bool, bool, Tuple[int, int], List[str]]:
- bool: Diffusersかどうかを示すフラグ。
- bool: Schnellかどうかを示すフラグ。
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
- List[str]: チェックポイントに含まれるキーのリスト。
"""
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = []
for ckpt_path in ckpt_paths:
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())
# if the key has annoying prefix, remove it
if keys[0].startswith("model.diffusion_model."):
keys = [key.replace("model.diffusion_model.", "") for key in keys]
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
# check number of double and single blocks
if not is_diffusers:
max_double_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
)
max_single_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
)
else:
max_double_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
]
)
max_single_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
]
)
num_double_blocks = max_double_block_index + 1
num_single_blocks = max_single_block_index + 1
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params
# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
def load_controlnet(
ckpt_path: Optional[str], is_schnell: bool, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device(device):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
if ckpt_path is not None:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = controlnet.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded ControlNet: {info}")
return controlnet
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
"initializer_factor": 1.0,
"logit_scale_init_value": 2.6592,
"model_type": "clip",
"projection_dim": 768,
# "text_config": {
"_name_or_path": "",
"add_cross_attention": False,
"architectures": None,
"attention_dropout": 0.0,
"bad_words_ids": None,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"dropout": 0.0,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 77,
"min_length": 0,
"model_type": "clip_text_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": 1,
"prefix": None,
"problem_type": None,
"projection_dim": 768,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"sep_token_id": None,
"task_specific_params": None,
"temperature": 1.0,
"tie_encoder_decoder": False,
"tie_word_embeddings": True,
"tokenizer_class": None,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": None,
"torchscript": False,
"transformers_version": "4.16.0.dev0",
"use_bfloat16": False,
"vocab_size": 49408,
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
# },
# "text_config_dict": {
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"projection_dim": 768,
# },
# "torch_dtype": "float32",
# "transformers_version": None,
}
config = CLIPConfig(**CLIPL_CONFIG)
with init_empty_weights():
clip = CLIPTextModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_t5xxl(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"use_cache": true,
"vocab_size": 32128
}
"""
config = json.loads(T5_CONFIG_JSON)
config = T5Config(**config)
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
# region Diffusers
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(num_double_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(num_single_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
return diffusers_to_bfl_map
def convert_diffusers_sd_to_bfl(
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for diffusers_key, tensor in diffusers_sd.items():
if diffusers_key in diffusers_to_bfl_map:
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
return flux_sd
# endregion

View File

@@ -5,7 +5,7 @@ from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

View File

@@ -6,8 +6,10 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -55,12 +57,18 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -113,7 +121,12 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
"""
sd3: only supports "m", flux: only supports "dev"
"""
# if state_dict is None, hash is not calculated
metadata = {}
@@ -126,6 +139,13 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -142,9 +162,12 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -202,7 +225,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl:
if sdxl or sd3 is not None or flux is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
@@ -213,7 +236,9 @@ def build_metadata(
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
@@ -236,7 +261,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -250,7 +275,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

1428
library/sd3_models.py Normal file

File diff suppressed because it is too large Load Diff

945
library/sd3_train_utils.py Normal file
View File

@@ -0,0 +1,945 @@
import argparse
import math
import os
import toml
import json
import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
# from transformers import CLIPTokenizer
# from library import model_util
# , sdxl_model_util, train_util, sdxl_original_unet
# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_base, train_util
def save_models(
ckpt_path: str,
mmdit: Optional[sd3_models.MMDiT],
vae: Optional[sd3_models.SDVAE],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
):
r"""
Save models to checkpoint file. Only supports unified checkpoint format.
"""
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("model.diffusion_model.", mmdit.state_dict())
update_sd("first_stage_model.", vae.state_dict())
# do not support unified checkpoint format for now
# if clip_l is not None:
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
# if clip_g is not None:
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
# if t5xxl is not None:
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if clip_l is not None:
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
save_file(clip_l.state_dict(), clip_l_path)
if clip_g is not None:
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
t5xxl_state_dict = t5xxl.state_dict()
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
shared_weight = t5xxl_state_dict["shared.weight"]
shared_weight_copy = shared_weight.detach().clone()
t5xxl_state_dict["shared.weight"] = shared_weight_copy
save_file(t5xxl_state_dict, t5xxl_path)
def save_sd3_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_sd3_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
def add_sd3_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
required=False,
help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--clip_g",
type=str,
required=False,
help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--t5xxl",
type=str,
required=False,
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--save_clip",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--save_t5xxl",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--t5xxl_device",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
)
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=256,
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--pos_emb_random_crop_rate",
type=float,
default=0.0,
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
parser.add_argument(
"--enable_scaled_pos_embed",
action="store_true",
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# Dependencies of Diffusers noise sampler has been removed for clarity in training
parser.add_argument(
"--training_shift",
type=float,
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
# temporary copied from sd3_minimal_inferece.py
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def max_denoise(model_sampling, sigmas):
max_sigma = float(model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def do_sample(
height: int,
width: int,
seed: int,
cond: Tuple[torch.Tensor, torch.Tensor],
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
guidance_scale: float,
dtype: torch.dtype,
device: str,
):
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
latent = latent.to(dtype).to(device)
# noise = get_noise(seed, latent).to(device)
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = (
torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
.to(latent.dtype)
.to(device)
)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_all_sigmas(model_sampling, steps).to(device)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
# with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
mmdit.prepare_block_swap_before_forward()
return x
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
mmdit,
vae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
mmdit = accelerator.unwrap_model(mmdit)
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
mmdit: sd3_models.MMDiT,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
vae: sd3_models.SDVAE,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
clean_memory_on_device(accelerator.device)
with accelerator.autocast(), torch.no_grad():
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# region Diffusers
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
# endregion
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)
indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)
# sigmas according to flowmatching
sigmas = timesteps / 1000
sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas

302
library/sd3_utils.py Normal file
View File

@@ -0,0 +1,302 @@
from dataclasses import dataclass
import math
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models
# TODO move some of functions to model_util.py
from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
logger.info(f"Analyzing state dict state...")
# analyze configs
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large, medium or 3-medium
if qk_norm is not None:
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
model_type = "3-medium"
params = sd3_models.SD3Params(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_in_features=context_embedder_in_features,
context_embedder_out_features=context_embedder_out_features,
model_type=model_type,
)
logger.info(f"Analyzed state dict state: {params}")
return params
def load_mmdit(
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
mmdit_sd = {}
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
# load MMDiT
logger.info("Building MMDit")
params = analyze_state_dict_state(mmdit_sd)
with init_empty_weights():
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
logger.info("Loading state dict...")
info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
logger.info(f"Loaded MMDiT: {info}")
return mmdit
def load_clip_l(
clip_l_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_l_sd = None
if clip_l_path is None:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_l_path is None:
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
return None
# load clip_l
logger.info("Building CLIP-L")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_l_sd is None:
logger.info(f"Loading state dict from {clip_l_path}")
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if "text_projection.weight" not in clip_l_sd:
logger.info("Adding text_projection.weight to clip_l_sd")
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_clip_g(
clip_g_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_g_sd = None
if state_dict is not None:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_g_path is None:
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
return None
# load clip_g
logger.info("Building CLIP-G")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_g_sd is None:
logger.info(f"Loading state dict from {clip_g_path}")
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-G: {info}")
return clip
def load_t5xxl(
t5xxl_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
t5xxl_sd = None
if state_dict is not None:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
elif t5xxl_path is None:
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
return None
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
def load_vae(
vae_path: Optional[str],
vae_dtype: Optional[Union[str, torch.dtype]],
device: Optional[Union[str, torch.device]],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_safetensors(vae_path, device, disable_mmap)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
logger.info("Building VAE")
vae = sd3_models.SDVAE(vae_dtype, device)
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
return vae
# endregion
class ModelSamplingDiscreteFlow:
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift=1.0):
self.shift = shift
timesteps = 1000
self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image

View File

@@ -13,12 +13,20 @@ from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image
from library import sdxl_model_util, sdxl_train_util, train_util
from library import (
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sdxl,
train_util,
sdxl_original_unet,
sdxl_original_control_net,
)
try:
@@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
vae: AutoencoderKL,
text_encoder: List[CLIPTextModel],
tokenizer: List[CLIPTokenizer],
unet: UNet2DConditionModel,
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
@@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if text_pool is not None:
text_pool = text_pool.repeat(1, num_images_per_prompt)
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if uncond_pool is not None:
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
return text_embeddings, text_pool, None, None
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet: sdxl_original_control_net.SdxlControlNet = None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
text_embeddings_list = []
text_pool = None
uncond_embeddings_list = []
uncond_pool = None
for i in range(len(self.tokenizers)):
self.tokenizer = self.tokenizers[i]
self.text_encoder = self.text_encoders[i]
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2=i == 1,
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
)
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
if do_classifier_free_guidance:
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, input_ids, weights
)
text_embeddings_list.append(text_embeddings)
uncond_embeddings_list.append(uncond_embeddings)
if tp1 is not None:
text_pool = tp1
if up1 is not None:
uncond_pool = up1
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
else:
uncond_embeddings = None
uncond_pool = None
unet_dtype = self.unet.dtype
dtype = unet_dtype
@@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# create size embs and concat embeddings for SDXL
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
crop_size = torch.zeros_like(orig_size)
target_size = orig_size
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
# make conditionings
text_pool = text_pool.to(device, dtype)
if do_classifier_free_guidance:
text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
uncond_pool = uncond_pool.to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
vector_embedding = torch.cat([uncond_vector, cond_vector])
else:
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
text_embedding = text_embeddings.to(device, dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
@@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# FIXME SD1 ControlNet is not working
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
if controlnet is not None:
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
else:
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
# perform guidance

View File

@@ -1,4 +1,5 @@
import torch
import safetensors
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
@@ -7,9 +8,11 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
@@ -163,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
# model_version is reserved for future use
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
if disable_mmap:
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
else:
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:

View File

@@ -0,0 +1,272 @@
# some parts are modified from Diffusers library (Apache License 2.0)
import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
class ControlNetConditioningEmbedding(nn.Module):
def __init__(self):
super().__init__()
dims = [16, 32, 96, 256]
self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(dims) - 1):
channel_in = dims[i]
channel_out = dims[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
nn.init.zeros_(self.conv_out.weight) # zero module weight
nn.init.zeros_(self.conv_out.bias) # zero module bias
def forward(self, x):
x = self.conv_in(x)
x = F.silu(x)
for block in self.blocks:
x = block(x)
x = F.silu(x)
x = self.conv_out(x)
return x
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
def __init__(self, multiplier: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
self.multiplier = multiplier
# remove unet layers
self.output_blocks = nn.ModuleList([])
del self.out
self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
self.controlnet_down_blocks = nn.ModuleList([])
for dim in dims:
self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
unet_sd = unet.state_dict()
unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
sd = super().state_dict()
sd.update(unet_sd)
info = super().load_state_dict(sd, strict=True, assign=True)
return info
def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
# convert state_dict to SAI format
unet_sd = {}
for k in list(state_dict.keys()):
if not k.startswith("controlnet_"):
unet_sd[k] = state_dict.pop(k)
unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
state_dict.update(unet_sd)
super().load_state_dict(state_dict, strict=strict, assign=assign)
def state_dict(self, destination=None, prefix="", keep_vars=False):
# convert state_dict to Diffusers format
state_dict = super().state_dict(destination, prefix, keep_vars)
control_net_sd = {}
for k in list(state_dict.keys()):
if k.startswith("controlnet_"):
control_net_sd[k] = state_dict.pop(k)
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
state_dict.update(control_net_sd)
return state_dict
def forward(
self,
x: torch.Tensor,
timesteps: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
cond_image: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
multiplier = self.multiplier if self.multiplier is not None else 1.0
hs = []
for i, module in enumerate(self.input_blocks):
h = call_module(module, h, emb, context)
if i == 0:
h = self.controlnet_cond_embedding(cond_image) + h
hs.append(self.controlnet_down_blocks[i](h) * multiplier)
h = call_module(self.middle_block, h, emb, context)
h = self.controlnet_mid_block(h) * multiplier
return hs, h
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
"""
This class is for training purpose only.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
h = h + mid_add
for module in self.output_blocks:
resi = hs.pop() + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
if __name__ == "__main__":
import time
logger.info("create unet")
unet = SdxlControlledUNet()
unet.to("cuda", torch.bfloat16)
unet.set_use_sdpa(True)
unet.set_gradient_checkpointing(True)
unet.train()
logger.info("create control_net")
control_net = SdxlControlNet()
control_net.to("cuda")
control_net.set_use_sdpa(True)
control_net.set_gradient_checkpointing(True)
control_net.train()
logger.info("Initialize control_net from unet")
control_net.init_from_unet(unet)
unet.requires_grad_(False)
control_net.requires_grad_(True)
# 使用メモリ量確認用の疑似学習ループ
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# import transformers
# optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
scaler = torch.cuda.amp.GradScaler(enabled=True)
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
txt = torch.randn(batch_size, 77, 2048).cuda()
vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
output = unet(x, t, txt, vector, input_resi_add, mid_add)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info("finish training")
sd = control_net.state_dict()
from safetensors.torch import save_file
save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")

View File

@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
@@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
"""
_self = self.delegate
@@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
if mid_add is not None:
h = h + mid_add
for module in _self.output_blocks:
# Deep Shrink
@@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
resi = hs.pop()
if input_resi_add is not None:
resi = resi + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0

View File

@@ -5,16 +5,18 @@ from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -44,6 +46,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors,
)
# work on low-ram device
@@ -60,7 +63,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +78,7 @@ def _load_target_model(
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -323,7 +326,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
@@ -332,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
@@ -355,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
# assert (
# not hasattr(args, "weighted_captions") or not args.weighted_captions
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
@@ -369,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
def sample_images(*args, **kwargs):
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

570
library/strategy_base.py Normal file
View File

@@ -0,0 +1,570 @@
# base class for platform strategies. this file defines the interface for strategies
import os
import re
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
_re_attention = re.compile(
r"""\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError
def _get_weighted_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
max_length includes starting and ending tokens.
"""
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in TokenizeStrategy._re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(text: str, max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
No padding, starting or ending token is included.
"""
truncated = False
texts_and_weights = parse_prompt_attention(text)
tokens = []
weights = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
tokens += token
# copy the weight by length of token
weights += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(tokens) > max_length:
truncated = True
break
# truncate
if len(tokens) > max_length:
truncated = True
tokens = tokens[:max_length]
weights = weights[:max_length]
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
return tokens, weights
if max_length is None:
max_length = tokenizer.model_max_length
tokens, weights = get_prompts_with_weights(text, max_length - 2)
tokens, weights = pad_tokens_and_weights(
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
)
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
def _get_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
if weighted:
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
else:
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変えるx <EOS> なら結果的に変化なし)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
if weighted:
weights = weights.squeeze(0)
new_weights = torch.ones(input_ids.shape)
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
b = i // (tokenizer.model_max_length - 2)
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
weights = new_weights
if weighted:
return input_ids, weights
return input_ids
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:param weights: list of weight tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
self._is_weighted = is_weighted
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def is_partial(self):
return self._is_partial
@property
def is_weighted(self):
return self._is_weighted
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
class LatentsCachingStrategy:
# TODO commonize utillity functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
kwargs = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)

271
library/strategy_flux.py Normal file
View File

@@ -0,0 +1,271 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, t5_tokens, t5_attn_mask]
class FluxTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_t5_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
# clip_l is None when using T5 only
if clip_l is not None and l_tokens is not None:
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
else:
l_pooled = None
# t5xxl is None when using CLIP only
if t5xxl is not None and t5_tokens is not None:
# t5_out is [b, max length, 4096]
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
# if zero_pad_t5_output:
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for FluxTokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = FluxTokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

171
library/strategy_sd.py Normal file
View File

@@ -0,0 +1,171 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
class SdTokenizeStrategy(TokenizeStrategy):
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
"""
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
if v2:
self.tokenizer = self._load_tokenizer(
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
for t in text:
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
tokens_list.append(tokens)
weights_list.append(weights)
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
tokens = tokens.to(text_encoder.device)
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return [encoder_hidden_states]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
weights = weights_list[0].to(encoder_hidden_states.device)
# apply weights
if weights.shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for i in range(weights.shape[1]):
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
:, i, 1:-1
].unsqueeze(-1)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

420
library/strategy_sd3.py Normal file
View File

@@ -0,0 +1,420 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_attn_mask = l_tokens["attention_mask"]
g_attn_mask = g_tokens["attention_mask"]
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: Optional[CLIPTextModel]
clip_g: Optional[CLIPTextModelWithProjection]
t5xxl: Optional[T5EncoderModel]
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
l_attn_mask = None
g_attn_mask = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
# drop some members of the batch: we do not call clip_l and clip_g for dropped members
batch_size, l_seq_len = l_tokens.shape
g_seq_len = g_tokens.shape[1]
non_drop_l_indices = []
non_drop_g_indices = []
for i in range(l_tokens.shape[0]):
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if not drop_l:
non_drop_l_indices.append(i)
if not drop_g:
non_drop_g_indices.append(i)
# filter out dropped members
if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
l_tokens = l_tokens[non_drop_l_indices]
l_attn_mask = l_attn_mask[non_drop_l_indices]
if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
g_tokens = g_tokens[non_drop_g_indices]
g_attn_mask = g_attn_mask[non_drop_g_indices]
# call clip_l for non-dropped members
if len(non_drop_l_indices) > 0:
nd_l_attn_mask = l_attn_mask.to(clip_l.device)
prompt_embeds = clip_l(
l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_l_pooled = prompt_embeds[0]
nd_l_out = prompt_embeds.hidden_states[-2]
if len(non_drop_g_indices) > 0:
nd_g_attn_mask = g_attn_mask.to(clip_g.device)
prompt_embeds = clip_g(
g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_g_pooled = prompt_embeds[0]
nd_g_out = prompt_embeds.hidden_states[-2]
# fill in the dropped members
if len(non_drop_l_indices) == batch_size:
l_pooled = nd_l_pooled
l_out = nd_l_out
else:
# model output is always float32 because of the models are wrapped with Accelerator
l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
if len(non_drop_l_indices) > 0:
l_pooled[non_drop_l_indices] = nd_l_pooled
l_out[non_drop_l_indices] = nd_l_out
l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
if len(non_drop_g_indices) == batch_size:
g_pooled = nd_g_pooled
g_out = nd_g_out
else:
g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
if len(non_drop_g_indices) > 0:
g_pooled[non_drop_g_indices] = nd_g_pooled
g_out[non_drop_g_indices] = nd_g_out
g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is None or t5_tokens is None:
t5_out = None
t5_attn_mask = None
else:
# drop some members of the batch: we do not call t5xxl for dropped members
batch_size, t5_seq_len = t5_tokens.shape
non_drop_t5_indices = []
for i in range(t5_tokens.shape[0]):
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
if not drop_t5:
non_drop_t5_indices.append(i)
# filter out dropped members
if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
t5_tokens = t5_tokens[non_drop_t5_indices]
t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
# call t5xxl for non-dropped members
if len(non_drop_t5_indices) > 0:
nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
nd_t5_out, _ = t5xxl(
t5_tokens.to(t5xxl.device),
nd_t5_attn_mask if apply_t5_attn_mask else None,
return_dict=False,
output_hidden_states=True,
)
# fill in the dropped members
if len(non_drop_t5_indices) == batch_size:
t5_out = nd_t5_out
else:
t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
if len(non_drop_t5_indices) > 0:
t5_out[non_drop_t5_indices] = nd_t5_out
t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
# masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
lg_out: torch.Tensor,
t5_out: torch.Tensor,
lg_pooled: torch.Tensor,
l_attn_mask: torch.Tensor,
g_attn_mask: torch.Tensor,
t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings
if lg_out is not None:
for i in range(lg_out.shape[0]):
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
if drop_l:
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
if l_attn_mask is not None:
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
if drop_g:
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
if g_attn_mask is not None:
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
if t5_out is not None:
for i in range(t5_out.shape[0]):
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
if drop_t5:
t5_out[i] = torch.zeros_like(t5_out[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if t5_out is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"]
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

306
library/strategy_sdxl.py Normal file
View File

@@ -0,0 +1,306 @@
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
if max_length is None:
self.max_length = self.tokenizer1.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return (
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens1_list, tokens2_list = [], []
weights1_list, weights2_list = [], []
for t in text:
tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
tokens1_list.append(tokens1)
tokens2_list.append(tokens2)
weights1_list.append(weights1)
weights2_list.append(weights2)
return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
torch.stack(weights1_list, dim=0),
torch.stack(weights2_list, dim=0),
]
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def _pool_workaround(
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n,77
# find index for EOS token
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
]
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output
def _get_hidden_states_sdxl(
self,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
if input_ids1.size()[1] == 1:
max_token_length = None
else:
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
input_ids2 = input_ids2.to(text_encoder2.device)
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
return hidden_states1, hidden_states2, pool2
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Args:
tokenize_strategy: TokenizeStrategy
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
text_encoder1, text_encoder2 = models
unwrapped_text_encoder2 = None
else:
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
tokens1, tokens2 = tokens
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
)
return [hidden_states1, hidden_states2, pool2]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
# apply weights
if weights_list[0].shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
for i in range(weight.shape[1]):
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
:, i, 1:-1
].unsqueeze(-1)
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens_list, weights_list
)
else:
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,29 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
import json
import struct
import torch
import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -79,10 +90,315 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
Args:
s: string representation of the dtype
default_dtype: default dtype to return if s is None
Returns:
torch.dtype: the corresponding torch.dtype
Raises:
ValueError: if the dtype is not supported
Examples:
>>> str_to_dtype("float32")
torch.float32
>>> str_to_dtype("fp32")
torch.float32
>>> str_to_dtype("float16")
torch.float16
>>> str_to_dtype("fp16")
torch.float16
>>> str_to_dtype("bfloat16")
torch.bfloat16
>>> str_to_dtype("bf16")
torch.bfloat16
>>> str_to_dtype("fp8")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fn")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fnuz")
torch.float8_e4m3fnuz
>>> str_to_dtype("fp8_e5m2")
torch.float8_e5m2
>>> str_to_dtype("fp8_e5m2fnuz")
torch.float8_e5m2fnuz
"""
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
# Convert back to cv2 format
if has_alpha:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
else:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix