mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'main' into val
This commit is contained in:
@@ -85,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
is_reg: bool = False
|
||||
class_tokens: Optional[str] = None
|
||||
caption_extension: str = ".caption"
|
||||
cache_info: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -96,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams):
|
||||
class ControlNetSubsetParams(BaseSubsetParams):
|
||||
conditioning_data_dir: str = None
|
||||
caption_extension: str = ".caption"
|
||||
cache_info: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -205,6 +207,7 @@ class ConfigSanitizer:
|
||||
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
"caption_extension": str,
|
||||
"class_tokens": str,
|
||||
"cache_info": bool,
|
||||
}
|
||||
DB_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("image_dir"): str,
|
||||
@@ -217,6 +220,7 @@ class ConfigSanitizer:
|
||||
}
|
||||
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
"caption_extension": str,
|
||||
"cache_info": bool,
|
||||
}
|
||||
CN_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("image_dir"): str,
|
||||
@@ -326,7 +330,10 @@ class ConfigSanitizer:
|
||||
|
||||
self.dataset_schema = validate_flex_dataset
|
||||
elif support_dreambooth:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
if support_controlnet:
|
||||
self.dataset_schema = self.cn_dataset_schema
|
||||
else:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
elif support_finetuning:
|
||||
self.dataset_schema = self.ft_dataset_schema
|
||||
elif support_controlnet:
|
||||
@@ -578,7 +585,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
" ",
|
||||
)
|
||||
|
||||
logger.info(f'{info}')
|
||||
logger.info(f"{info}")
|
||||
|
||||
# print validation info
|
||||
info = ""
|
||||
@@ -662,7 +669,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
|
||||
@@ -3,11 +3,14 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
||||
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
||||
else:
|
||||
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
||||
loss = loss * snr_weight
|
||||
@@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
weight = 1/torch.sqrt(snr_t)
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
@@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
139
library/deepspeed_utils.py
Normal file
139
library/deepspeed_utils.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
||||
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
||||
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
||||
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_device",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_nvme_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_device",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_nvme_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_init_flag",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||
"Only applicable with ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_save_16bit_model",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_master_weights_and_gradients",
|
||||
action="store_true",
|
||||
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
||||
)
|
||||
|
||||
|
||||
def prepare_deepspeed_args(args: argparse.Namespace):
|
||||
if not args.deepspeed:
|
||||
return
|
||||
|
||||
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||
args.max_data_loader_n_workers = 1
|
||||
|
||||
|
||||
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
if not args.deepspeed:
|
||||
return None
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
zero_stage=args.zero_stage,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_clipping=args.max_grad_norm,
|
||||
offload_optimizer_device=args.offload_optimizer_device,
|
||||
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
||||
offload_param_device=args.offload_param_device,
|
||||
offload_param_nvme_path=args.offload_param_nvme_path,
|
||||
zero3_init_flag=args.zero3_init_flag,
|
||||
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||
)
|
||||
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
||||
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
||||
logger.info("[DeepSpeed] full fp16 enable.")
|
||||
else:
|
||||
logger.info(
|
||||
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
||||
)
|
||||
|
||||
if args.offload_optimizer_device is not None:
|
||||
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
||||
logger.info("[DeepSpeed] building cpu_adam done.")
|
||||
|
||||
return deepspeed_plugin
|
||||
|
||||
|
||||
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
||||
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
# remove None from models
|
||||
models = {k: v for k, v in models.items() if v is not None}
|
||||
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
@@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.nn.Module.cuda = torch.nn.Module.xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
@@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 2024
|
||||
ipex._C._DeviceProperties.minor = 0
|
||||
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
|
||||
@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
mat2[start_idx:end_idx],
|
||||
out=out
|
||||
)
|
||||
torch.xpu.synchronize(input.device)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
torch.xpu.synchronize(input.device)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||
if query.device.type != "xpu":
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||
|
||||
# Slice SDPA
|
||||
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
@@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
@@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
torch.xpu.synchronize(query.device)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
return module.to("xpu")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
if antialias or align_corners is not None or mode == 'bicubic':
|
||||
return_device = tensor.device
|
||||
return_dtype = tensor.dtype
|
||||
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
||||
@@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
else:
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
||||
@wraps(torch.Tensor.pin_memory)
|
||||
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
||||
if device is None:
|
||||
device = "xpu"
|
||||
if check_device(device):
|
||||
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
||||
else:
|
||||
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
||||
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
@wraps(torch.UntypedStorage.__init__)
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
@@ -216,7 +226,9 @@ def torch_empty(*args, device=None, **kwargs):
|
||||
|
||||
original_torch_randn = torch.randn
|
||||
@wraps(torch.randn)
|
||||
def torch_randn(*args, device=None, **kwargs):
|
||||
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||
if dtype == bytes:
|
||||
dtype = None
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
@@ -256,11 +268,13 @@ def torch_Generator(device=None):
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||
def torch_load(f, map_location=None, *args, **kwargs):
|
||||
if map_location is None:
|
||||
map_location = "xpu"
|
||||
if check_device(map_location):
|
||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
@@ -268,6 +282,7 @@ def ipex_hijacks():
|
||||
torch.tensor = torch_tensor
|
||||
torch.Tensor.to = Tensor_to
|
||||
torch.Tensor.cuda = Tensor_cuda
|
||||
torch.Tensor.pin_memory = Tensor_pin_memory
|
||||
torch.UntypedStorage.__init__ = UntypedStorage_init
|
||||
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
||||
torch.empty = torch_empty
|
||||
|
||||
@@ -31,8 +31,10 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IN_CHANNELS: int = 4
|
||||
@@ -1074,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
|
||||
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
@@ -1132,7 +1134,7 @@ class InferSdxlUNet2DConditionModel:
|
||||
# call original model's methods
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.delegate, name)
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.delegate(*args, **kwargs)
|
||||
|
||||
@@ -1164,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = _self.time_embed(t_emb)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
|
||||
@@ -63,12 +63,14 @@ from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import imagesize
|
||||
import cv2
|
||||
import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -423,6 +425,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
is_reg: bool,
|
||||
class_tokens: Optional[str],
|
||||
caption_extension: str,
|
||||
cache_info: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator: str,
|
||||
@@ -471,6 +474,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
self.cache_info = cache_info
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, DreamBoothSubset):
|
||||
@@ -540,6 +544,7 @@ class ControlNetSubset(BaseSubset):
|
||||
image_dir: str,
|
||||
conditioning_data_dir: str,
|
||||
caption_extension: str,
|
||||
cache_info: bool,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
@@ -587,6 +592,7 @@ class ControlNetSubset(BaseSubset):
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
self.cache_info = cache_info
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ControlNetSubset):
|
||||
@@ -707,6 +713,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# if caption is multiline, random choice one line
|
||||
if "\n" in caption:
|
||||
caption = random.choice(caption.split("\n"))
|
||||
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
@@ -725,6 +735,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
else:
|
||||
# if caption is multiline, use the first line
|
||||
caption = caption.split("\n")[0]
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
@@ -1087,8 +1100,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
return image.size
|
||||
return imagesize.get(image_path)
|
||||
|
||||
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
|
||||
img = load_image(image_path)
|
||||
@@ -1417,6 +1429,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
@@ -1466,7 +1480,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
def read_caption(img_path, caption_extension, enable_wildcard):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
@@ -1485,7 +1499,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
@@ -1505,17 +1522,54 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
|
||||
use_cached_info_for_subset = subset.cache_info
|
||||
if use_cached_info_for_subset:
|
||||
logger.info(
|
||||
f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
|
||||
)
|
||||
if not os.path.isfile(info_cache_file):
|
||||
logger.warning(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
f"image info file not found. You can ignore this warning if this is the first time to use this subset"
|
||||
+ " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}"
|
||||
)
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
if cap_for_img is None:
|
||||
captions.append(subset.class_tokens)
|
||||
use_cached_info_for_subset = False
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
|
||||
with open(info_cache_file, "r", encoding="utf-8") as f:
|
||||
metas = json.load(f)
|
||||
img_paths = list(metas.keys())
|
||||
sizes = [meta["resolution"] for meta in metas.values()]
|
||||
|
||||
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
|
||||
else:
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
sizes = [None] * len(img_paths)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
captions = [meta["caption"] for meta in metas.values()]
|
||||
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
|
||||
else:
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
)
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
if cap_for_img is None:
|
||||
captions.append(subset.class_tokens)
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
@@ -1532,12 +1586,24 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
logger.warning(missing_caption)
|
||||
return img_paths, captions
|
||||
|
||||
if not use_cached_info_for_subset and subset.cache_info:
|
||||
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
|
||||
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
|
||||
matas = {}
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
matas[img_path] = {"caption": caption, "resolution": list(size)}
|
||||
with open(info_cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(matas, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
|
||||
|
||||
# if sizes are not set, image size will be read in make_buckets
|
||||
return img_paths, captions, sizes
|
||||
|
||||
logger.info("prepare images.")
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[ImageInfo] = []
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
@@ -1551,7 +1617,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, captions = load_dreambooth_dir(subset)
|
||||
img_paths, captions, sizes = load_dreambooth_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
logger.warning(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
|
||||
@@ -1563,10 +1629,12 @@ class DreamBoothDataset(BaseDataset):
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
reg_infos.append(info)
|
||||
reg_infos.append((info, subset))
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
@@ -1587,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info in reg_infos:
|
||||
for info, subset in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
n += info.num_repeats
|
||||
@@ -1679,10 +1747,24 @@ class FineTuningDataset(BaseDataset):
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
if caption is None:
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ", " + tags
|
||||
tags_list.append(tags)
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
@@ -1835,11 +1917,15 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
assert (
|
||||
not subset.random_crop
|
||||
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
|
||||
db_subset = DreamBoothSubset(
|
||||
subset.image_dir,
|
||||
False,
|
||||
None,
|
||||
subset.caption_extension,
|
||||
subset.cache_info,
|
||||
subset.num_repeats,
|
||||
subset.shuffle_caption,
|
||||
subset.caption_separator,
|
||||
@@ -1891,7 +1977,7 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
cond_imgs_with_img = set()
|
||||
cond_imgs_with_pair = set()
|
||||
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
|
||||
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
|
||||
subset = None
|
||||
@@ -1905,20 +1991,22 @@ class ControlNetDataset(BaseDataset):
|
||||
logger.warning(f"not directory: {subset.conditioning_data_dir}")
|
||||
continue
|
||||
|
||||
img_basename = os.path.basename(info.absolute_path)
|
||||
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
|
||||
if not os.path.exists(ctrl_img_path):
|
||||
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
|
||||
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
|
||||
if len(ctrl_img_path) < 1:
|
||||
missing_imgs.append(img_basename)
|
||||
continue
|
||||
ctrl_img_path = ctrl_img_path[0]
|
||||
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path
|
||||
|
||||
info.cond_img_path = ctrl_img_path
|
||||
cond_imgs_with_img.add(ctrl_img_path)
|
||||
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive
|
||||
|
||||
extra_imgs = []
|
||||
for subset in subsets:
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
extra_imgs.extend(
|
||||
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
|
||||
)
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
#assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
#assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
@@ -2967,7 +3055,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--save_state",
|
||||
action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
|
||||
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_state_on_train_end",
|
||||
action="store_true",
|
||||
help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する",
|
||||
)
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
|
||||
@@ -3050,6 +3143,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -3112,12 +3206,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
default=None,
|
||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_offset_random_strength",
|
||||
action="store_true",
|
||||
help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multires_noise_iterations",
|
||||
type=int,
|
||||
@@ -3131,6 +3231,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
|
||||
+ "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ip_noise_gamma_random_strength",
|
||||
action="store_true",
|
||||
help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
|
||||
+ "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--perlin_noise",
|
||||
# type=int,
|
||||
@@ -3166,6 +3272,27 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loss_type",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "huber", "smooth_l1"],
|
||||
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huber_schedule",
|
||||
type=str,
|
||||
default="snr",
|
||||
choices=["constant", "exponential", "snr"],
|
||||
help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"
|
||||
+ " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huber_c",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lowram",
|
||||
@@ -3274,6 +3401,74 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
|
||||
|
||||
def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masked_loss",
|
||||
action="store_true",
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
|
||||
# verify command line args for training
|
||||
def verify_command_line_training_args(args: argparse.Namespace):
|
||||
# if wandb is enabled, the command line is exposed to the public
|
||||
# check whether sensitive options are included in the command line arguments
|
||||
# if so, warn or inform the user to move them to the configuration file
|
||||
# wandbが有効な場合、コマンドラインが公開される
|
||||
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
|
||||
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する
|
||||
|
||||
wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
|
||||
if not wandb_enabled:
|
||||
return
|
||||
|
||||
sensitive_args = ["wandb_api_key", "huggingface_token"]
|
||||
sensitive_path_args = [
|
||||
"pretrained_model_name_or_path",
|
||||
"vae",
|
||||
"tokenizer_cache_dir",
|
||||
"train_data_dir",
|
||||
"conditioning_data_dir",
|
||||
"reg_data_dir",
|
||||
"output_dir",
|
||||
"logging_dir",
|
||||
]
|
||||
|
||||
for arg in sensitive_args:
|
||||
if getattr(args, arg, None) is not None:
|
||||
logger.warning(
|
||||
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
|
||||
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
|
||||
)
|
||||
|
||||
# if path is absolute, it may include sensitive information
|
||||
for arg in sensitive_path_args:
|
||||
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
|
||||
logger.info(
|
||||
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
|
||||
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
|
||||
)
|
||||
|
||||
if getattr(args, "config_file", None) is not None:
|
||||
logger.info(
|
||||
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
|
||||
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
|
||||
)
|
||||
|
||||
# other sensitive options
|
||||
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
|
||||
logger.info(
|
||||
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
|
||||
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
|
||||
)
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
r"""
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
@@ -3329,6 +3524,18 @@ def verify_training_args(args: argparse.Namespace):
|
||||
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
|
||||
)
|
||||
|
||||
if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
|
||||
logger.warning(
|
||||
"sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
|
||||
)
|
||||
args.sample_every_n_epochs = None
|
||||
|
||||
if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
|
||||
logger.warning(
|
||||
"sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
|
||||
)
|
||||
args.sample_every_n_steps = None
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
@@ -3337,6 +3544,12 @@ def add_dataset_arguments(
|
||||
parser.add_argument(
|
||||
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_info",
|
||||
action="store_true",
|
||||
help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
|
||||
+ " / メタ情報(キャプションとサイズ)をキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
|
||||
)
|
||||
@@ -3565,7 +3778,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
exit(1)
|
||||
|
||||
logger.info(f"Loading settings from {config_path}...")
|
||||
with open(config_path, "r") as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_dict = toml.load(f)
|
||||
|
||||
# combine all sections into one
|
||||
@@ -4074,6 +4287,10 @@ def load_tokenizer(args: argparse.Namespace):
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
"""
|
||||
|
||||
if args.logging_dir is None:
|
||||
logging_dir = None
|
||||
else:
|
||||
@@ -4119,6 +4336,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
),
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -4126,6 +4345,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
@@ -4196,7 +4416,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
@@ -4207,7 +4426,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
@@ -4216,7 +4434,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
@@ -4683,11 +4900,47 @@ def save_sd_model_on_train_end_common(
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
||||
|
||||
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
|
||||
# as. In the future there may be a smarter way
|
||||
|
||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
|
||||
timestep = timesteps.item()
|
||||
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
huber_c = math.exp(-alpha * timestep)
|
||||
elif args.huber_schedule == "snr":
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
|
||||
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
|
||||
elif args.huber_schedule == "constant":
|
||||
huber_c = args.huber_c
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
|
||||
|
||||
timesteps = timesteps.repeat(b_size).to(device)
|
||||
elif args.loss_type == "l2":
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
||||
huber_c = 1 # may be anything, as it's not used
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
|
||||
timesteps = timesteps.long()
|
||||
|
||||
return timesteps, huber_c
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
if args.noise_offset_random_strength:
|
||||
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
|
||||
else:
|
||||
noise_offset = args.noise_offset
|
||||
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
|
||||
if args.multires_noise_iterations:
|
||||
noise = custom_train_functions.pyramid_noise_like(
|
||||
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
|
||||
@@ -4698,17 +4951,44 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
|
||||
else:
|
||||
strength = args.ip_noise_gamma
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
|
||||
else:
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
return noise, noisy_latents, timesteps
|
||||
return noise, noisy_latents, timesteps, huber_c
|
||||
|
||||
|
||||
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
|
||||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
|
||||
):
|
||||
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif loss_type == "smooth_l1":
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
|
||||
return loss
|
||||
|
||||
|
||||
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
||||
|
||||
Reference in New Issue
Block a user