mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into fast_image_sizes
This commit is contained in:
@@ -2,9 +2,12 @@
|
||||
# license: Apache-2.0 License
|
||||
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from typing import Optional
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
@@ -752,18 +755,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
return self._forward(img, txt, vec, pe, txt_attention_mask)
|
||||
|
||||
# def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(
|
||||
# create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=USE_REENTRANT
|
||||
# )
|
||||
# else:
|
||||
# return self._forward(img, txt, vec, pe)
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
@@ -809,7 +800,7 @@ class SingleStreamBlock(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
@@ -817,16 +808,35 @@ class SingleStreamBlock(nn.Module):
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# make attention mask if not None
|
||||
attn_mask = None
|
||||
if txt_attention_mask is not None:
|
||||
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
|
||||
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
|
||||
attn_mask = torch.cat(
|
||||
(
|
||||
attn_mask,
|
||||
torch.ones(
|
||||
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
|
||||
),
|
||||
),
|
||||
dim=1,
|
||||
) # b, seq_len + img_len = x_len
|
||||
|
||||
# broadcast attn_mask to all heads
|
||||
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if not self.cpu_offload_checkpointing:
|
||||
return checkpoint(self._forward, x, vec, pe, use_reentrant=False)
|
||||
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
|
||||
|
||||
# cpu offload checkpointing
|
||||
|
||||
@@ -838,19 +848,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=False)
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
return self._forward(x, vec, pe)
|
||||
|
||||
# def forward(self, x: Tensor, vec: Tensor, pe: Tensor):
|
||||
# if self.training and self.gradient_checkpointing:
|
||||
# def create_custom_forward(func):
|
||||
# def custom_forward(*inputs):
|
||||
# return func(*inputs)
|
||||
# return custom_forward
|
||||
# return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, vec, pe, use_reentrant=USE_REENTRANT)
|
||||
# else:
|
||||
# return self._forward(x, vec, pe)
|
||||
return self._forward(x, vec, pe, txt_attention_mask)
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
@@ -918,8 +920,10 @@ class Flux(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.cpu_offload_checkpointing = False
|
||||
self.double_blocks_to_swap = None
|
||||
self.single_blocks_to_swap = None
|
||||
self.blocks_to_swap = None
|
||||
|
||||
self.thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
self.num_block_units = len(self.double_blocks) + len(self.single_blocks) // 2
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -957,38 +961,53 @@ class Flux(nn.Module):
|
||||
|
||||
print("FLUX: Gradient checkpointing disabled.")
|
||||
|
||||
def enable_block_swap(self, double_blocks: Optional[int], single_blocks: Optional[int]):
|
||||
self.double_blocks_to_swap = double_blocks
|
||||
self.single_blocks_to_swap = single_blocks
|
||||
def enable_block_swap(self, num_blocks: int):
|
||||
self.blocks_to_swap = num_blocks
|
||||
|
||||
n = 1 # async block swap. 1 is enough
|
||||
# n = 2
|
||||
# n = max(1, os.cpu_count() // 2)
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=n)
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||
# assume model is on cpu
|
||||
if self.double_blocks_to_swap:
|
||||
if self.blocks_to_swap:
|
||||
save_double_blocks = self.double_blocks
|
||||
self.double_blocks = None
|
||||
if self.single_blocks_to_swap:
|
||||
save_single_blocks = self.single_blocks
|
||||
self.double_blocks = None
|
||||
self.single_blocks = None
|
||||
|
||||
self.to(device)
|
||||
|
||||
if self.double_blocks_to_swap:
|
||||
if self.blocks_to_swap:
|
||||
self.double_blocks = save_double_blocks
|
||||
if self.single_blocks_to_swap:
|
||||
self.single_blocks = save_single_blocks
|
||||
|
||||
def get_block_unit(self, index: int):
|
||||
if index < len(self.double_blocks):
|
||||
return (self.double_blocks[index],)
|
||||
else:
|
||||
index -= len(self.double_blocks)
|
||||
index *= 2
|
||||
return self.single_blocks[index], self.single_blocks[index + 1]
|
||||
|
||||
def get_unit_index(self, is_double: bool, index: int):
|
||||
if is_double:
|
||||
return index
|
||||
else:
|
||||
return len(self.double_blocks) + index // 2
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
# move last n blocks to cpu: they are on cuda
|
||||
if self.double_blocks_to_swap:
|
||||
for i in range(len(self.double_blocks) - self.double_blocks_to_swap):
|
||||
self.double_blocks[i].to(self.device)
|
||||
for i in range(len(self.double_blocks) - self.double_blocks_to_swap, len(self.double_blocks)):
|
||||
self.double_blocks[i].to("cpu") # , non_blocking=True)
|
||||
if self.single_blocks_to_swap:
|
||||
for i in range(len(self.single_blocks) - self.single_blocks_to_swap):
|
||||
self.single_blocks[i].to(self.device)
|
||||
for i in range(len(self.single_blocks) - self.single_blocks_to_swap, len(self.single_blocks)):
|
||||
self.single_blocks[i].to("cpu") # , non_blocking=True)
|
||||
# make: first n blocks are on cuda, and last n blocks are on cpu
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
# raise ValueError("Block swap is not enabled.")
|
||||
return
|
||||
for i in range(self.num_block_units - self.blocks_to_swap):
|
||||
for b in self.get_block_unit(i):
|
||||
b.to(self.device)
|
||||
for i in range(self.num_block_units - self.blocks_to_swap, self.num_block_units):
|
||||
for b in self.get_block_unit(i):
|
||||
b.to("cpu")
|
||||
clean_memory_on_device(self.device)
|
||||
|
||||
def forward(
|
||||
@@ -1018,69 +1037,73 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
if not self.double_blocks_to_swap:
|
||||
if not self.blocks_to_swap:
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.double_blocks_to_swap):
|
||||
block = self.double_blocks[len(self.double_blocks) - self.double_blocks_to_swap + block_idx]
|
||||
if block.parameters().__next__().device.type != "cpu":
|
||||
block.to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved double block {len(self.double_blocks) - self.double_blocks_to_swap + block_idx} to cpu.")
|
||||
futures = {}
|
||||
|
||||
block = self.double_blocks[block_idx]
|
||||
if block.parameters().__next__().device.type == "cpu":
|
||||
block.to(self.device)
|
||||
# print(f"Moved double block {block_idx} to cuda.")
|
||||
def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
|
||||
def move_blocks(bidx_to_cpu, blocks_to_cpu, bidx_to_cuda, blocks_to_cuda):
|
||||
# print(f"Moving {bidx_to_cpu} to cpu.")
|
||||
for block in blocks_to_cpu:
|
||||
block.to("cpu", non_blocking=True)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# print(f"Moving {bidx_to_cuda} to cuda.")
|
||||
for block in blocks_to_cuda:
|
||||
block.to(self.device, non_blocking=True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
|
||||
return block_idx_to_cpu, block_idx_to_cuda
|
||||
|
||||
blocks_to_cpu = self.get_block_unit(block_idx_to_cpu)
|
||||
blocks_to_cuda = self.get_block_unit(block_idx_to_cuda)
|
||||
# print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
|
||||
return self.thread_pool.submit(move_blocks, block_idx_to_cpu, blocks_to_cpu, block_idx_to_cuda, blocks_to_cuda)
|
||||
|
||||
def wait_for_blocks_move(block_idx, ftrs):
|
||||
if block_idx not in ftrs:
|
||||
return
|
||||
# print(f"Waiting for move blocks: {block_idx}")
|
||||
# start_time = time.perf_counter()
|
||||
ftr = ftrs.pop(block_idx)
|
||||
ftr.result()
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")
|
||||
|
||||
to_cpu_block_index = 0
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
|
||||
moving = block_idx >= len(self.double_blocks) - self.double_blocks_to_swap
|
||||
if moving:
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved double block {block_idx} to cuda.")
|
||||
# print(f"Double block {block_idx}")
|
||||
unit_idx = self.get_unit_index(is_double=True, index=block_idx)
|
||||
wait_for_blocks_move(unit_idx, futures)
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if moving:
|
||||
self.double_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved double block {to_cpu_block_index} to cpu.")
|
||||
to_cpu_block_index += 1
|
||||
if unit_idx < self.blocks_to_swap:
|
||||
block_idx_to_cpu = unit_idx
|
||||
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
|
||||
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
||||
futures[block_idx_to_cuda] = future
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if not self.single_blocks_to_swap:
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
else:
|
||||
# make sure first n blocks are on cuda, and last n blocks are on cpu at beginning
|
||||
for block_idx in range(self.single_blocks_to_swap):
|
||||
block = self.single_blocks[len(self.single_blocks) - self.single_blocks_to_swap + block_idx]
|
||||
if block.parameters().__next__().device.type != "cpu":
|
||||
block.to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved single block {len(self.single_blocks) - self.single_blocks_to_swap + block_idx} to cpu.")
|
||||
|
||||
block = self.single_blocks[block_idx]
|
||||
if block.parameters().__next__().device.type == "cpu":
|
||||
block.to(self.device)
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
|
||||
to_cpu_block_index = 0
|
||||
for block_idx, block in enumerate(self.single_blocks):
|
||||
# move last n blocks to cuda: they are on cpu, and move first n blocks to cpu: they are on cuda
|
||||
moving = block_idx >= len(self.single_blocks) - self.single_blocks_to_swap
|
||||
if moving:
|
||||
block.to(self.device) # move to cuda
|
||||
# print(f"Moved single block {block_idx} to cuda.")
|
||||
# print(f"Single block {block_idx}")
|
||||
unit_idx = self.get_unit_index(is_double=False, index=block_idx)
|
||||
if block_idx % 2 == 0:
|
||||
wait_for_blocks_move(unit_idx, futures)
|
||||
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
|
||||
if moving:
|
||||
self.single_blocks[to_cpu_block_index].to("cpu") # , non_blocking=True)
|
||||
# print(f"Moved single block {to_cpu_block_index} to cpu.")
|
||||
to_cpu_block_index += 1
|
||||
if block_idx % 2 == 1 and unit_idx < self.blocks_to_swap:
|
||||
block_idx_to_cpu = unit_idx
|
||||
block_idx_to_cuda = self.num_block_units - self.blocks_to_swap + unit_idx
|
||||
future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
|
||||
futures[block_idx_to_cuda] = future
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
@@ -1089,6 +1112,7 @@ class Flux(nn.Module):
|
||||
vec = vec.to(self.device)
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
@@ -1250,10 +1274,11 @@ class FluxLower(nn.Module):
|
||||
txt: Tensor,
|
||||
vec: Tensor | None = None,
|
||||
pe: Tensor | None = None,
|
||||
txt_attention_mask: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -58,7 +58,7 @@ def sample_images(
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
@@ -66,7 +66,8 @@ def sample_images(
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
flux = accelerator.unwrap_model(flux)
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
if text_encoders is not None:
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
|
||||
prompts = load_prompts(args.sample_prompts)
|
||||
@@ -84,7 +85,7 @@ def sample_images(
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
@@ -134,7 +135,7 @@ def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
flux: flux_models.Flux,
|
||||
text_encoders: List[CLIPTextModel],
|
||||
text_encoders: Optional[List[CLIPTextModel]],
|
||||
ae: flux_models.AutoEncoder,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
@@ -186,14 +187,26 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
te_outputs = sample_prompts_te_outputs[prompt]
|
||||
else:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = te_outputs
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -240,17 +253,19 @@ def sample_image_inference(
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
# send images to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
import wandb
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log(
|
||||
{f"sample_{i}": wandb.Image(
|
||||
image,
|
||||
caption=prompt # positive prompt as a caption
|
||||
)},
|
||||
commit=False
|
||||
)
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
@@ -297,6 +312,7 @@ def denoise(
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -309,7 +325,8 @@ def denoise(
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
@@ -370,7 +387,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz = latents.shape[0]
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
@@ -380,9 +397,30 @@ def get_noisy_model_input_and_timesteps(
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -559,9 +597,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid"],
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
||||
default="sigma",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
|
||||
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
import einops
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
|
||||
|
||||
# temporary copy from sd3_utils TODO refactor
|
||||
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
):
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> flux_models.Flux:
|
||||
logger.info(f"Building Flux model {name}")
|
||||
with torch.device("meta"):
|
||||
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
||||
model = flux_models.Flux(flux_models.configs[name].params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
@@ -167,7 +171,9 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> T5EncoderModel:
|
||||
def load_t5xxl(
|
||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> T5EncoderModel:
|
||||
T5_CONFIG_JSON = """
|
||||
{
|
||||
"architectures": [
|
||||
@@ -213,6 +219,11 @@ def load_t5xxl(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.devi
|
||||
return t5xxl
|
||||
|
||||
|
||||
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
|
||||
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
|
||||
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
|
||||
|
||||
|
||||
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
|
||||
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
|
||||
|
||||
@@ -604,17 +604,19 @@ def sample_image_inference(
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
# send images to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
import wandb
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log(
|
||||
{f"sample_{i}": wandb.Image(
|
||||
image,
|
||||
caption=prompt # positive prompt as a caption
|
||||
)},
|
||||
commit=False
|
||||
)
|
||||
|
||||
|
||||
# region Diffusers
|
||||
|
||||
@@ -5,8 +5,7 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
from library import flux_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
@@ -60,7 +59,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
clip_l, t5xxl = models
|
||||
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
|
||||
l_tokens, t5_tokens = tokens[:2]
|
||||
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
||||
|
||||
@@ -81,6 +80,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
else:
|
||||
t5_out = None
|
||||
txt_ids = None
|
||||
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
|
||||
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
||||
|
||||
@@ -99,6 +99,8 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
@@ -143,6 +145,14 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
):
|
||||
if not self.warn_fp8_weights:
|
||||
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
||||
logger.warning(
|
||||
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
|
||||
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
|
||||
)
|
||||
self.warn_fp8_weights = True
|
||||
|
||||
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import shutil
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
@@ -44,7 +45,11 @@ from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
import transformers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers.optimization import (
|
||||
SchedulerType as DiffusersSchedulerType,
|
||||
TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
|
||||
)
|
||||
from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
DDPMScheduler,
|
||||
@@ -73,7 +78,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, pil_resize
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -656,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
def adjust_min_max_bucket_reso_by_steps(
|
||||
self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
|
||||
) -> Tuple[int, int]:
|
||||
# make min/max bucket reso to be multiple of bucket_reso_steps
|
||||
if min_bucket_reso % bucket_reso_steps != 0:
|
||||
adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
|
||||
logger.warning(
|
||||
f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
|
||||
f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
|
||||
)
|
||||
min_bucket_reso = adjusted_min_bucket_reso
|
||||
if max_bucket_reso % bucket_reso_steps != 0:
|
||||
adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
|
||||
logger.warning(
|
||||
f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
|
||||
f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
|
||||
)
|
||||
max_bucket_reso = adjusted_max_bucket_reso
|
||||
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
|
||||
return min_bucket_reso, max_bucket_reso
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
@@ -988,9 +1021,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
# split by resolution and some conditions
|
||||
class Condition:
|
||||
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
|
||||
self.reso = reso
|
||||
self.flip_aug = flip_aug
|
||||
self.alpha_mask = alpha_mask
|
||||
self.random_crop = random_crop
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||
batch: List[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
@@ -1011,20 +1061,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
|
||||
# if cache to disk, don't cache latents in non-main process, set to info only
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
@@ -1036,9 +1089,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
@@ -1049,9 +1101,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
# split by resolution and some conditions
|
||||
class Condition:
|
||||
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
|
||||
self.reso = reso
|
||||
self.flip_aug = flip_aug
|
||||
self.alpha_mask = alpha_mask
|
||||
self.random_crop = random_crop
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||
batch: List[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
@@ -1072,28 +1141,31 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
|
||||
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
||||
return
|
||||
|
||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
r"""
|
||||
@@ -1663,12 +1735,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
@@ -1708,7 +1777,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
logger.warning(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
return [], [], []
|
||||
|
||||
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
|
||||
use_cached_info_for_subset = subset.cache_info
|
||||
@@ -2062,6 +2131,9 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
@@ -2284,9 +2356,7 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = cv2.resize(
|
||||
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
|
||||
)
|
||||
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -2425,7 +2495,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
if alpha_mask:
|
||||
if "alpha_mask" not in npz:
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != reso: # HxW
|
||||
if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
@@ -2534,7 +2604,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
||||
alpha_mask = example["alpha_masks"][j]
|
||||
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
||||
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
|
||||
alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)
|
||||
if os.name == "nt":
|
||||
cv2.imshow("alpha_mask", alpha_mask)
|
||||
|
||||
@@ -2680,7 +2750,10 @@ def trim_and_resize_if_required(
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
else:
|
||||
image = pil_resize(image, resized_size)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
@@ -3270,11 +3343,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
|
||||
def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
def int_or_float(value):
|
||||
if value.endswith("%"):
|
||||
try:
|
||||
return float(value[:-1]) / 100.0
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
|
||||
try:
|
||||
float_value = float(value)
|
||||
if float_value >= 1:
|
||||
return int(value)
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
|
||||
|
||||
parser.add_argument(
|
||||
"--optimizer_type",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, "
|
||||
"Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
|
||||
"DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
|
||||
"AdaFactor. "
|
||||
"Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.",
|
||||
)
|
||||
|
||||
# backward compatibility
|
||||
@@ -3305,6 +3396,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
||||
)
|
||||
|
||||
# parser.add_argument(
|
||||
# "--optimizer_schedulefree_wrapper",
|
||||
# action="store_true",
|
||||
# help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
|
||||
# )
|
||||
|
||||
# parser.add_argument(
|
||||
# "--schedulefree_wrapper_args",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# nargs="*",
|
||||
# help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
|
||||
# )
|
||||
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_args",
|
||||
@@ -3322,9 +3427,17 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps",
|
||||
type=int,
|
||||
type=int_or_float,
|
||||
default=0,
|
||||
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
|
||||
help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
|
||||
" / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_decay_steps",
|
||||
type=int_or_float,
|
||||
default=0,
|
||||
help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
|
||||
" / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_num_cycles",
|
||||
@@ -3344,6 +3457,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
|
||||
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_timescale",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`"
|
||||
+ " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_min_lr_ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
|
||||
+ " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
|
||||
)
|
||||
|
||||
|
||||
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
||||
@@ -4071,8 +4198,20 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
|
||||
)
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||
parser.add_argument(
|
||||
"--min_bucket_reso",
|
||||
type=int,
|
||||
default=256,
|
||||
help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_bucket_reso",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
@@ -4290,7 +4429,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
if args.use_8bit_adam:
|
||||
@@ -4343,6 +4482,7 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
lr = args.learning_rate
|
||||
optimizer = None
|
||||
optimizer_class = None
|
||||
|
||||
if optimizer_type == "Lion".lower():
|
||||
try:
|
||||
@@ -4400,7 +4540,8 @@ def get_optimizer(args, trainable_params):
|
||||
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
if optimizer_class is not None:
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "PagedAdamW".lower():
|
||||
logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
@@ -4562,26 +4703,159 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
if optimizer_type == "AdamWScheduleFree".lower():
|
||||
optimizer_class = sf.AdamWScheduleFree
|
||||
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "SGDScheduleFree".lower():
|
||||
optimizer_class = sf.SGDScheduleFree
|
||||
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
# make optimizer as train mode: we don't need to call train again, because eval will not be called in training loop
|
||||
optimizer.train()
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||
logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
|
||||
if "." not in optimizer_type:
|
||||
optimizer_module = torch.optim
|
||||
else:
|
||||
values = optimizer_type.split(".")
|
||||
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
||||
optimizer_type = values[-1]
|
||||
case_sensitive_optimizer_type = args.optimizer_type # not lower
|
||||
logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
|
||||
|
||||
optimizer_class = getattr(optimizer_module, optimizer_type)
|
||||
if "." not in case_sensitive_optimizer_type: # from torch.optim
|
||||
optimizer_module = torch.optim
|
||||
else: # from other library
|
||||
values = case_sensitive_optimizer_type.split(".")
|
||||
optimizer_module = importlib.import_module(".".join(values[:-1]))
|
||||
case_sensitive_optimizer_type = values[-1]
|
||||
|
||||
optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
"""
|
||||
# wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
|
||||
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
|
||||
schedulefree_wrapper_kwargs = {}
|
||||
if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
|
||||
for arg in args.schedulefree_wrapper_args:
|
||||
key, value = arg.split("=")
|
||||
value = ast.literal_eval(value)
|
||||
schedulefree_wrapper_kwargs[key] = value
|
||||
|
||||
sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
|
||||
sf_wrapper.train() # make optimizer as train mode
|
||||
|
||||
# we need to make optimizer as a subclass of torch.optim.Optimizer, we make another Proxy class over SFWrapper
|
||||
class OptimizerProxy(torch.optim.Optimizer):
|
||||
def __init__(self, sf_wrapper):
|
||||
self._sf_wrapper = sf_wrapper
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._sf_wrapper, name)
|
||||
|
||||
# override properties
|
||||
@property
|
||||
def state(self):
|
||||
return self._sf_wrapper.state
|
||||
|
||||
@state.setter
|
||||
def state(self, state):
|
||||
self._sf_wrapper.state = state
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
return self._sf_wrapper.param_groups
|
||||
|
||||
@param_groups.setter
|
||||
def param_groups(self, param_groups):
|
||||
self._sf_wrapper.param_groups = param_groups
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
return self._sf_wrapper.defaults
|
||||
|
||||
@defaults.setter
|
||||
def defaults(self, defaults):
|
||||
self._sf_wrapper.defaults = defaults
|
||||
|
||||
def add_param_group(self, param_group):
|
||||
self._sf_wrapper.add_param_group(param_group)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._sf_wrapper.load_state_dict(state_dict)
|
||||
|
||||
def state_dict(self):
|
||||
return self._sf_wrapper.state_dict()
|
||||
|
||||
def zero_grad(self):
|
||||
self._sf_wrapper.zero_grad()
|
||||
|
||||
def step(self, closure=None):
|
||||
self._sf_wrapper.step(closure)
|
||||
|
||||
def train(self):
|
||||
self._sf_wrapper.train()
|
||||
|
||||
def eval(self):
|
||||
self._sf_wrapper.eval()
|
||||
|
||||
# isinstance チェックをパスするためのメソッド
|
||||
def __instancecheck__(self, instance):
|
||||
return isinstance(instance, (type(self), Optimizer))
|
||||
|
||||
optimizer = OptimizerProxy(sf_wrapper)
|
||||
|
||||
logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
|
||||
"""
|
||||
|
||||
# for logging
|
||||
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
|
||||
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
|
||||
|
||||
return optimizer_name, optimizer_args, optimizer
|
||||
|
||||
|
||||
def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
|
||||
if not is_schedulefree_optimizer(optimizer, args):
|
||||
# return dummy func
|
||||
return lambda: None, lambda: None
|
||||
|
||||
# get train and eval functions from optimizer
|
||||
train_fn = optimizer.train
|
||||
eval_fn = optimizer.eval
|
||||
|
||||
return train_fn, eval_fn
|
||||
|
||||
|
||||
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
|
||||
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
||||
|
||||
|
||||
def get_dummy_scheduler(optimizer: Optimizer) -> Any:
|
||||
# dummy scheduler for schedulefree optimizer. supports only empty step(), get_last_lr() and optimizers.
|
||||
# this scheduler is used for logging only.
|
||||
# this isn't be wrapped by accelerator because of this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
|
||||
class DummyScheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
def get_last_lr(self):
|
||||
return [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
return DummyScheduler(optimizer)
|
||||
|
||||
|
||||
# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
|
||||
# Add some checking and features to the original function.
|
||||
|
||||
@@ -4590,11 +4864,23 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
"""
|
||||
# if schedulefree optimizer, return dummy scheduler
|
||||
if is_schedulefree_optimizer(optimizer, args):
|
||||
return get_dummy_scheduler(optimizer)
|
||||
|
||||
name = args.lr_scheduler
|
||||
num_warmup_steps: Optional[int] = args.lr_warmup_steps
|
||||
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
||||
num_warmup_steps: Optional[int] = (
|
||||
int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
|
||||
)
|
||||
num_decay_steps: Optional[int] = (
|
||||
int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
|
||||
)
|
||||
num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
|
||||
num_cycles = args.lr_scheduler_num_cycles
|
||||
power = args.lr_scheduler_power
|
||||
timescale = args.lr_scheduler_timescale
|
||||
min_lr_ratio = args.lr_scheduler_min_lr_ratio
|
||||
|
||||
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
|
||||
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
||||
@@ -4630,15 +4916,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
# logger.info(f"adafactor scheduler init lr {initial_lr}")
|
||||
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
|
||||
|
||||
if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
|
||||
name = DiffusersSchedulerType(name)
|
||||
schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
||||
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
||||
|
||||
if name == SchedulerType.PIECEWISE_CONSTANT:
|
||||
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
@@ -4646,6 +4934,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
|
||||
|
||||
if name == SchedulerType.INVERSE_SQRT:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
@@ -4664,7 +4955,46 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
|
||||
)
|
||||
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
|
||||
if name == SchedulerType.COSINE_WITH_MIN_LR:
|
||||
return schedule_func(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_cycles=num_cycles / 2,
|
||||
min_lr_rate=min_lr_ratio,
|
||||
**lr_scheduler_kwargs,
|
||||
)
|
||||
|
||||
# these schedulers do not require `num_decay_steps`
|
||||
if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
|
||||
return schedule_func(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
**lr_scheduler_kwargs,
|
||||
)
|
||||
|
||||
# All other schedulers require `num_decay_steps`
|
||||
if num_decay_steps is None:
|
||||
raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
|
||||
if name == SchedulerType.WARMUP_STABLE_DECAY:
|
||||
return schedule_func(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_stable_steps=num_stable_steps,
|
||||
num_decay_steps=num_decay_steps,
|
||||
num_cycles=num_cycles / 2,
|
||||
min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
|
||||
**lr_scheduler_kwargs,
|
||||
)
|
||||
|
||||
return schedule_func(
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_decay_steps=num_decay_steps,
|
||||
**lr_scheduler_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
@@ -5312,34 +5642,27 @@ def save_sd_model_on_train_end_common(
|
||||
|
||||
|
||||
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
||||
|
||||
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
|
||||
# as. In the future there may be a smarter way
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
|
||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
|
||||
timestep = timesteps.item()
|
||||
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
huber_c = math.exp(-alpha * timestep)
|
||||
huber_c = torch.exp(-alpha * timesteps)
|
||||
elif args.huber_schedule == "snr":
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
|
||||
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
elif args.huber_schedule == "constant":
|
||||
huber_c = args.huber_c
|
||||
huber_c = torch.full((b_size,), args.huber_c)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
timesteps = timesteps.repeat(b_size).to(device)
|
||||
huber_c = huber_c.to(device)
|
||||
elif args.loss_type == "l2":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
||||
huber_c = 1 # may be anything, as it's not used
|
||||
huber_c = None # may be anything, as it's not used
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
|
||||
timesteps = timesteps.long()
|
||||
|
||||
timesteps = timesteps.long().to(device)
|
||||
return timesteps, huber_c
|
||||
|
||||
|
||||
@@ -5378,21 +5701,22 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
return noise, noisy_latents, timesteps, huber_c
|
||||
|
||||
|
||||
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
|
||||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
|
||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
|
||||
):
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif loss_type == "smooth_l1":
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
@@ -5678,7 +6002,7 @@ def sample_images_common(
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
if torch.cuda.is_available() and cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
@@ -5712,11 +6036,13 @@ def sample_image_inference(
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
else:
|
||||
# True random sample image generation
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.seed()
|
||||
|
||||
scheduler = get_my_scheduler(
|
||||
sample_sampler=sampler_name,
|
||||
@@ -5751,8 +6077,9 @@ def sample_image_inference(
|
||||
controlnet_image=controlnet_image,
|
||||
)
|
||||
|
||||
with torch.cuda.device(torch.cuda.current_device()):
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(torch.cuda.current_device()):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
|
||||
@@ -5766,17 +6093,14 @@ def sample_image_inference(
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
# send images to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
import wandb
|
||||
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -10,6 +10,9 @@ from torchvision import transforms
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
@@ -82,6 +85,66 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
"""
|
||||
Convert a string to a torch.dtype
|
||||
|
||||
Args:
|
||||
s: string representation of the dtype
|
||||
default_dtype: default dtype to return if s is None
|
||||
|
||||
Returns:
|
||||
torch.dtype: the corresponding torch.dtype
|
||||
|
||||
Raises:
|
||||
ValueError: if the dtype is not supported
|
||||
|
||||
Examples:
|
||||
>>> str_to_dtype("float32")
|
||||
torch.float32
|
||||
>>> str_to_dtype("fp32")
|
||||
torch.float32
|
||||
>>> str_to_dtype("float16")
|
||||
torch.float16
|
||||
>>> str_to_dtype("fp16")
|
||||
torch.float16
|
||||
>>> str_to_dtype("bfloat16")
|
||||
torch.bfloat16
|
||||
>>> str_to_dtype("bf16")
|
||||
torch.bfloat16
|
||||
>>> str_to_dtype("fp8")
|
||||
torch.float8_e4m3fn
|
||||
>>> str_to_dtype("fp8_e4m3fn")
|
||||
torch.float8_e4m3fn
|
||||
>>> str_to_dtype("fp8_e4m3fnuz")
|
||||
torch.float8_e4m3fnuz
|
||||
>>> str_to_dtype("fp8_e5m2")
|
||||
torch.float8_e5m2
|
||||
>>> str_to_dtype("fp8_e5m2fnuz")
|
||||
torch.float8_e5m2fnuz
|
||||
"""
|
||||
if s is None:
|
||||
return default_dtype
|
||||
if s in ["bf16", "bfloat16"]:
|
||||
return torch.bfloat16
|
||||
elif s in ["fp16", "float16"]:
|
||||
return torch.float16
|
||||
elif s in ["fp32", "float32", "float"]:
|
||||
return torch.float32
|
||||
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
|
||||
return torch.float8_e4m3fn
|
||||
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
|
||||
return torch.float8_e4m3fnuz
|
||||
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
|
||||
return torch.float8_e5m2
|
||||
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
|
||||
return torch.float8_e5m2fnuz
|
||||
elif s in ["fp8", "float8"]:
|
||||
return torch.float8_e4m3fn # default fp8
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
@@ -198,7 +261,7 @@ class MemoryEfficientSafeOpen:
|
||||
if tensor_bytes is None:
|
||||
byte_tensor = torch.empty(0, dtype=torch.uint8)
|
||||
else:
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
tensor_bytes = bytearray(tensor_bytes) # make it writable
|
||||
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
|
||||
|
||||
# process float8 types
|
||||
@@ -241,6 +304,24 @@ class MemoryEfficientSafeOpen:
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
if has_alpha:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
||||
else:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
resized_pil = pil_image.resize(size, interpolation)
|
||||
|
||||
# Convert back to cv2 format
|
||||
if has_alpha:
|
||||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
|
||||
else:
|
||||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
|
||||
|
||||
return resized_cv2
|
||||
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
Reference in New Issue
Block a user