Merge branch 'sd3' into fast_image_sizes

This commit is contained in:
Kohya S
2024-10-01 08:38:09 +09:00
38 changed files with 3907 additions and 888 deletions

View File

@@ -2,9 +2,12 @@
# license: Apache-2.0 License
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
import math
from typing import Optional
import os
import time
from typing import Dict, List, Optional
from library.device_utils import init_ipex, clean_memory_on_device
@@ -752,18 +755,6 @@ class DoubleStreamBlock(nn.Module):
else:
return self._forward(img, txt, vec, pe, txt_attention_mask)
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
# if self.training and self.gradient_checkpointing:
# def create_custom_forward(func):
# def custom_forward(*inputs):
# return func(*inputs)
# return custom_forward
# return torch.utils.checkpoint.checkpoint(
# create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
# )
# else:
# return self._forward(img, txt, vec, pe)
class SingleStreamBlock(nn.Module):
"""
@@ -809,7 +800,7 @@ class SingleStreamBlock(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -817,16 +808,35 @@ class SingleStreamBlock(nn.Module):
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# make attention mask if not None
attn_mask = None
if txt_attention_mask is not None:
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
attn_mask = torch.cat(
(
attn_mask,
torch.ones(
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
),
),
dim=1,
) # b, seq_len + img_len = x_len
# broadcast attn_mask to all heads
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
# compute attention
attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
if self.training and self.gradient_checkpointing:
if not self.cpu_offload_checkpointing:
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
# cpu offload checkpointing
@@ -838,19 +848,11 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
)
else:
return self._forward(x, vec, pe)
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
# if self.training and self.gradient_checkpointing:
# def create_custom_forward(func):
# def custom_forward(*inputs):
# return func(*inputs)
# return custom_forward
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
# else:
# return self._forward(x, vec, pe)
return self._forward(x, vec, pe, txt_attention_mask)
class LastLayer(nn.Module):
@@ -918,8 +920,10 @@ class Flux(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
self.double_blocks_to_swap = None
self.single_blocks_to_swap = None
self.blocks_to_swap = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
@property
def device(self):
@@ -957,38 +961,53 @@ class Flux(nn.Module):
print("FLUX: Gradient checkpointing disabled.")
def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
self.double_blocks_to_swap = double_blocks
self.single_blocks_to_swap = single_blocks
def enable_block_swap(self, num_blocks: int):
self.blocks_to_swap = num_blocks
n = 1 # async block swap. 1 is enough
# n = 2
# n = max(1, os.cpu_count() // 2)
self.thread_pool = ThreadPoolExecutor(max_workers=n)
def move_to_device_except_swap_blocks(self, device: torch.device):
# assume model is on cpu
if self.double_blocks_to_swap:
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
self.double_blocks = None
if self.single_blocks_to_swap:
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.to(device)
if self.double_blocks_to_swap:
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
if self.single_blocks_to_swap:
self.single_blocks = save_single_blocks
def get_block_unit(self, index: int):
if index < len(self.double_blocks):
return (self.double_blocks[index],)
else:
index -= len(self.double_blocks)
index *= 2
return self.single_blocks[index], self.single_blocks[index + 1]
def get_unit_index(self, is_double: bool, index: int):
if is_double:
return index
else:
return len(self.double_blocks) + index // 2
def prepare_block_swap_before_forward(self):
# move last n blocks to cpu: they are on cuda
if self.double_blocks_to_swap:
for i in range(len(self.double_blocks) - self.double_blocks_to_swap):
self.double_blocks[i].to(self.device)
for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)):
self.double_blocks[i].to("cpu") # , non_blocking=True)
if self.single_blocks_to_swap:
for i in range(len(self.single_blocks) - self.single_blocks_to_swap):
self.single_blocks[i].to(self.device)
for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)):
self.single_blocks[i].to("cpu") # , non_blocking=True)
# make: first n blocks are on cuda, and last n blocks are on cpu
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
# raise ValueError("Block swap is not enabled.")
return
for i in range(self.num_block_units - self.blocks_to_swap):
for b in self.get_block_unit(i):
b.to(self.device)
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
for b in self.get_block_unit(i):
b.to("cpu")
clean_memory_on_device(self.device)
def forward(
@@ -1018,69 +1037,73 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if not self.double_blocks_to_swap:
if not self.blocks_to_swap:
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.double_blocks_to_swap):
block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
if block.parameters().__next__().device.type != "cpu":
block.to("cpu") # , non_blocking=True)
# print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
futures = {}
block = self.double_blocks[block_idx]
if block.parameters().__next__().device.type == "cpu":
block.to(self.device)
# print(f"Moved double block {block_idx} to cuda.")
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
# print(f"Moving {bidx_to_cpu} to cpu.")
for block in blocks_to_cpu:
block.to("cpu", non_blocking=True)
torch.cuda.empty_cache()
# print(f"Moving {bidx_to_cuda} to cuda.")
for block in blocks_to_cuda:
block.to(self.device, non_blocking=True)
torch.cuda.synchronize()
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
return block_idx_to_cpu, block_idx_to_cuda
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
def wait_for_blocks_move(block_idx, ftrs):
if block_idx not in ftrs:
return
# print(f"Waiting for move blocks: {block_idx}")
# start_time = time.perf_counter()
ftr = ftrs.pop(block_idx)
ftr.result()
# torch.cuda.synchronize()
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
to_cpu_block_index = 0
for block_idx, block in enumerate(self.double_blocks):
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
if moving:
block.to(self.device) # move to cuda
# print(f"Moved double block {block_idx} to cuda.")
# print(f"Double block {block_idx}")
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
wait_for_blocks_move(unit_idx, futures)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
# print(f"Moved double block {to_cpu_block_index} to cpu.")
to_cpu_block_index += 1
if unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
img = torch.cat((txt, img), 1)
img = torch.cat((txt, img), 1)
if not self.single_blocks_to_swap:
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
else:
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
for block_idx in range(self.single_blocks_to_swap):
block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
if block.parameters().__next__().device.type != "cpu":
block.to("cpu") # , non_blocking=True)
# print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
block = self.single_blocks[block_idx]
if block.parameters().__next__().device.type == "cpu":
block.to(self.device)
# print(f"Moved single block {block_idx} to cuda.")
to_cpu_block_index = 0
for block_idx, block in enumerate(self.single_blocks):
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
if moving:
block.to(self.device) # move to cuda
# print(f"Moved single block {block_idx} to cuda.")
# print(f"Single block {block_idx}")
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
if block_idx % 2 == 0:
wait_for_blocks_move(unit_idx, futures)
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
if moving:
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
# print(f"Moved single block {to_cpu_block_index} to cpu.")
to_cpu_block_index += 1
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
block_idx_to_cpu = unit_idx
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
futures[block_idx_to_cuda] = future
img = img[:, txt.shape[1] :, ...]
@@ -1089,6 +1112,7 @@ class Flux(nn.Module):
vec = vec.to(self.device)
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
@@ -1250,10 +1274,11 @@ class FluxLower(nn.Module):
txt: Tensor,
vec: Tensor | None = None,
pe: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

View File

@@ -58,7 +58,7 @@ def sample_images(
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -66,7 +66,8 @@ def sample_images(
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
@@ -84,7 +85,7 @@ def sample_images(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
@@ -134,7 +135,7 @@ def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: List[CLIPTextModel],
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
@@ -186,14 +187,26 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
te_outputs = sample_prompts_te_outputs[prompt]
else:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
@@ -240,17 +253,19 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
def time_shift(mu: float, sigma: float, t: torch.Tensor):
@@ -297,6 +312,7 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
pred = model(
img=img,
img_ids=img_ids,
@@ -309,7 +325,8 @@ def denoise(
)
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
@@ -370,7 +387,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -380,9 +397,30 @@ def get_noisy_model_input_and_timesteps(
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -559,9 +597,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid"],
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法sigma、random uniform、またはrandom normalのsigmoid。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",

View File

@@ -1,5 +1,5 @@
import json
from typing import Union
from typing import Optional, Union
import einops
import torch
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
# temporary copy from sd3_utils TODO refactor
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
def load_flow_model(
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
logger.info(f"Building Flux model {name}")
with torch.device("meta"):
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
model = flux_models.Flux(flux_models.configs[name].params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
@@ -167,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
return clip
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
@@ -213,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]

View File

@@ -604,17 +604,19 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
# region Diffusers

View File

@@ -5,8 +5,7 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import sd3_utils, train_util
from library import sd3_models
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -60,7 +59,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
@@ -81,6 +80,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
@@ -99,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
@@ -143,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]

View File

@@ -13,6 +13,7 @@ import shutil
import time
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
@@ -44,7 +45,11 @@ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers.optimization import (
SchedulerType as DiffusersSchedulerType,
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
)
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
StableDiffusionPipeline,
DDPMScheduler,
@@ -73,7 +78,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
@@ -656,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
def adjust_min_max_bucket_reso_by_steps(
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
# make min/max bucket reso to be multiple of bucket_reso_steps
if min_bucket_reso % bucket_reso_steps != 0:
adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
logger.warning(
f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
)
min_bucket_reso = adjusted_min_bucket_reso
if max_bucket_reso % bucket_reso_steps != 0:
adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
logger.warning(
f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
)
max_bucket_reso = adjusted_max_bucket_reso
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
return min_bucket_reso, max_bucket_reso
def set_seed(self, seed):
self.seed = seed
@@ -988,9 +1021,26 @@ class BaseDataset(torch.utils.data.Dataset):
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
# split by resolution and some conditions
class Condition:
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
self.reso = reso
self.flip_aug = flip_aug
self.alpha_mask = alpha_mask
self.random_crop = random_crop
def __eq__(self, other):
return (
self.reso == other.reso
and self.flip_aug == other.flip_aug
and self.alpha_mask == other.alpha_mask
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
current_condition = None
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
@@ -1011,20 +1061,23 @@ class BaseDataset(torch.utils.data.Dataset):
if cache_available: # do not add to batch
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= caching_strategy.batch_size:
batches.append(batch)
batches.append((current_condition, batch))
batch = []
current_condition = None
if len(batch) > 0:
batches.append(batch)
batches.append((current_condition, batch))
# if cache to disk, don't cache latents in non-main process, set to info only
if caching_strategy.cache_to_disk and not is_main_process:
@@ -1036,9 +1089,8 @@ class BaseDataset(torch.utils.data.Dataset):
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
@@ -1049,9 +1101,26 @@ class BaseDataset(torch.utils.data.Dataset):
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
# split by resolution and some conditions
class Condition:
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
self.reso = reso
self.flip_aug = flip_aug
self.alpha_mask = alpha_mask
self.random_crop = random_crop
def __eq__(self, other):
return (
self.reso == other.reso
and self.flip_aug == other.flip_aug
and self.alpha_mask == other.alpha_mask
and self.random_crop == other.random_crop
)
batches: List[Tuple[Condition, List[ImageInfo]]] = []
batch: List[ImageInfo] = []
current_condition = None
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
@@ -1072,28 +1141,31 @@ class BaseDataset(torch.utils.data.Dataset):
if cache_available: # do not add to batch
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
if len(batch) > 0 and current_condition != condition:
batches.append((current_condition, batch))
batch = []
batch.append(info)
current_condition = condition
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batches.append((current_condition, batch))
batch = []
current_condition = None
if len(batch) > 0:
batches.append(batch)
batches.append((current_condition, batch))
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
return
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
r"""
@@ -1663,12 +1735,9 @@ class DreamBoothDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -1708,7 +1777,7 @@ class DreamBoothDataset(BaseDataset):
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
logger.warning(f"not directory: {subset.image_dir}")
return [], []
return [], [], []
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
@@ -2062,6 +2131,9 @@ class FineTuningDataset(BaseDataset):
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
@@ -2284,9 +2356,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = cv2.resize(
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
)
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2425,7 +2495,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != reso: # HxW
if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso
return False
else:
if "alpha_mask" in npz:
@@ -2534,7 +2604,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
if "alpha_masks" in example and example["alpha_masks"] is not None:
alpha_mask = example["alpha_masks"][j]
logger.info(f"alpha mask size: {alpha_mask.size()}")
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)
if os.name == "nt":
cv2.imshow("alpha_mask", alpha_mask)
@@ -2680,7 +2750,10 @@ def trim_and_resize_if_required(
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image_height, image_width = image.shape[0:2]
@@ -3270,11 +3343,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
def add_optimizer_arguments(parser: argparse.ArgumentParser):
def int_or_float(value):
if value.endswith("%"):
try:
return float(value[:-1]) / 100.0
except ValueError:
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
try:
float_value = float(value)
if float_value >= 1:
return int(value)
return float(value)
except ValueError:
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
parser.add_argument(
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, "
"Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
"DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
"AdaFactor. "
"Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.",
)
# backward compatibility
@@ -3305,6 +3396,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
)
# parser.add_argument(
# "--optimizer_schedulefree_wrapper",
# action="store_true",
# help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
# )
# parser.add_argument(
# "--schedulefree_wrapper_args",
# type=str,
# default=None,
# nargs="*",
# help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ..."',
# )
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
"--lr_scheduler_args",
@@ -3322,9 +3427,17 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
type=int_or_float,
default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0",
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
" / 学習率のスケジューラをウォームアップするステップ数デフォルト0、または学習ステップの比率1未満のfloat値の場合",
)
parser.add_argument(
"--lr_decay_steps",
type=int_or_float,
default=0,
help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
" / 学習率のスケジューラを減衰させるステップ数デフォルト0、または学習ステップの比率1未満のfloat値の場合",
)
parser.add_argument(
"--lr_scheduler_num_cycles",
@@ -3344,6 +3457,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
)
parser.add_argument(
"--lr_scheduler_timescale",
type=int,
default=None,
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
+ " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
)
parser.add_argument(
"--lr_scheduler_min_lr_ratio",
type=float,
default=None,
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+ " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
)
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4071,8 +4198,20 @@ def add_dataset_arguments(
action="store_true",
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
parser.add_argument(
"--min_bucket_reso",
type=int,
default=256,
help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--max_bucket_reso",
type=int,
default=1024,
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
)
parser.add_argument(
"--bucket_reso_steps",
type=int,
@@ -4290,7 +4429,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -4343,6 +4482,7 @@ def get_optimizer(args, trainable_params):
lr = args.learning_rate
optimizer = None
optimizer_class = None
if optimizer_type == "Lion".lower():
try:
@@ -4400,7 +4540,8 @@ def get_optimizer(args, trainable_params):
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
if optimizer_class is not None:
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW".lower():
logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}")
@@ -4562,26 +4703,159 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
# make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop
optimizer.train()
if optimizer is None:
# 任意のoptimizerを使う
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
if "." not in optimizer_type:
optimizer_module = torch.optim
else:
values = optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
optimizer_type = values[-1]
case_sensitive_optimizer_type = args.optimizer_type # not lower
logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
optimizer_class = getattr(optimizer_module, optimizer_type)
if "." not in case_sensitive_optimizer_type: # from torch.optim
optimizer_module = torch.optim
else: # from other library
values = case_sensitive_optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
case_sensitive_optimizer_type = values[-1]
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
"""
# wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
schedulefree_wrapper_kwargs = {}
if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
for arg in args.schedulefree_wrapper_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
schedulefree_wrapper_kwargs[key] = value
sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
sf_wrapper.train() # make optimizer as train mode
# we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper
class OptimizerProxy(torch.optim.Optimizer):
def __init__(self, sf_wrapper):
self._sf_wrapper = sf_wrapper
def __getattr__(self, name):
return getattr(self._sf_wrapper, name)
# override properties
@property
def state(self):
return self._sf_wrapper.state
@state.setter
def state(self, state):
self._sf_wrapper.state = state
@property
def param_groups(self):
return self._sf_wrapper.param_groups
@param_groups.setter
def param_groups(self, param_groups):
self._sf_wrapper.param_groups = param_groups
@property
def defaults(self):
return self._sf_wrapper.defaults
@defaults.setter
def defaults(self, defaults):
self._sf_wrapper.defaults = defaults
def add_param_group(self, param_group):
self._sf_wrapper.add_param_group(param_group)
def load_state_dict(self, state_dict):
self._sf_wrapper.load_state_dict(state_dict)
def state_dict(self):
return self._sf_wrapper.state_dict()
def zero_grad(self):
self._sf_wrapper.zero_grad()
def step(self, closure=None):
self._sf_wrapper.step(closure)
def train(self):
self._sf_wrapper.train()
def eval(self):
self._sf_wrapper.eval()
# isinstance チェックをパスするためのメソッド
def __instancecheck__(self, instance):
return isinstance(instance, (type(self), Optimizer))
optimizer = OptimizerProxy(sf_wrapper)
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
"""
# for logging
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
return optimizer_name, optimizer_args, optimizer
def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
if not is_schedulefree_optimizer(optimizer, args):
# return dummy func
return lambda: None, lambda: None
# get train and eval functions from optimizer
train_fn = optimizer.train
eval_fn = optimizer.eval
return train_fn, eval_fn
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
def get_dummy_scheduler(optimizer: Optimizer) -> Any:
# dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
# this scheduler is used for logging only.
# this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
class DummyScheduler:
def __init__(self, optimizer: Optimizer):
self.optimizer = optimizer
def step(self):
pass
def get_last_lr(self):
return [group["lr"] for group in self.optimizer.param_groups]
return DummyScheduler(optimizer)
# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
# Add some checking and features to the original function.
@@ -4590,11 +4864,23 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""
# if schedulefree optimizer, return dummy scheduler
if is_schedulefree_optimizer(optimizer, args):
return get_dummy_scheduler(optimizer)
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_warmup_steps: Optional[int] = (
int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
)
num_decay_steps: Optional[int] = (
int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
)
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
timescale = args.lr_scheduler_timescale
min_lr_ratio = args.lr_scheduler_min_lr_ratio
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
@@ -4630,15 +4916,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
# logger.info(f"adafactor scheduler init lr {initial_lr}")
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
name = DiffusersSchedulerType(name)
schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
if name == SchedulerType.PIECEWISE_CONSTANT:
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -4646,6 +4934,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
if name == SchedulerType.INVERSE_SQRT:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4664,7 +4955,46 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
if name == SchedulerType.COSINE_WITH_MIN_LR:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles / 2,
min_lr_rate=min_lr_ratio,
**lr_scheduler_kwargs,
)
# these schedulers do not require `num_decay_steps`
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**lr_scheduler_kwargs,
)
# All other schedulers require `num_decay_steps`
if num_decay_steps is None:
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
if name == SchedulerType.WARMUP_STABLE_DECAY:
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_stable_steps=num_stable_steps,
num_decay_steps=num_decay_steps,
num_cycles=num_cycles / 2,
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
**lr_scheduler_kwargs,
)
return schedule_func(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_decay_steps=num_decay_steps,
**lr_scheduler_kwargs,
)
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
@@ -5312,34 +5642,27 @@ def save_sd_model_on_train_end_common(
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
huber_c = torch.exp(-alpha * timesteps)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
huber_c = torch.full((b_size,), args.huber_c)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
timesteps = timesteps.repeat(b_size).to(device)
huber_c = huber_c.to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
huber_c = None # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()
timesteps = timesteps.long().to(device)
return timesteps, huber_c
@@ -5378,21 +5701,22 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
return noise, noisy_latents, timesteps, huber_c
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
):
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
@@ -5678,7 +6002,7 @@ def sample_images_common(
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
@@ -5712,11 +6036,13 @@ def sample_image_inference(
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if torch.cuda.is_available():
torch.cuda.seed()
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
@@ -5751,8 +6077,9 @@ def sample_image_inference(
controlnet_image=controlnet_image,
)
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
image = pipeline.latents_to_image(latents)[0]
@@ -5766,17 +6093,14 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# endregion

View File

@@ -10,6 +10,9 @@ from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
def fire_in_thread(f, *args, **kwargs):
@@ -82,6 +85,66 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
Args:
s: string representation of the dtype
default_dtype: default dtype to return if s is None
Returns:
torch.dtype: the corresponding torch.dtype
Raises:
ValueError: if the dtype is not supported
Examples:
>>> str_to_dtype("float32")
torch.float32
>>> str_to_dtype("fp32")
torch.float32
>>> str_to_dtype("float16")
torch.float16
>>> str_to_dtype("fp16")
torch.float16
>>> str_to_dtype("bfloat16")
torch.bfloat16
>>> str_to_dtype("bf16")
torch.bfloat16
>>> str_to_dtype("fp8")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fn")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fnuz")
torch.float8_e4m3fnuz
>>> str_to_dtype("fp8_e5m2")
torch.float8_e5m2
>>> str_to_dtype("fp8_e5m2fnuz")
torch.float8_e5m2fnuz
"""
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
@@ -198,7 +261,7 @@ class MemoryEfficientSafeOpen:
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
@@ -241,6 +304,24 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
# Convert back to cv2 format
if has_alpha:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
else:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
return resized_cv2
# TODO make inf_utils.py