Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into fast_image_sizes
@@ -13,6 +13,7 @@ import shutil
 import time
 from typing import (
     Any,
+    Callable,
     Dict,
     List,
     NamedTuple,
@@ -44,7 +45,11 @@ from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 import transformers
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from diffusers.optimization import (
+    SchedulerType as DiffusersSchedulerType,
+    TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION,
+)
+from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
     DDPMScheduler,
@@ -73,7 +78,7 @@ import library.model_util as model_util
 import library.huggingface_util as huggingface_util
 import library.sai_model_spec as sai_model_spec
 import library.deepspeed_utils as deepspeed_utils
-from library.utils import setup_logging
+from library.utils import setup_logging, pil_resize

 setup_logging()
 import logging
@@ -656,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
         self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
         self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()

+    def adjust_min_max_bucket_reso_by_steps(
+        self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
+    ) -> Tuple[int, int]:
+        # make min/max bucket reso to be multiple of bucket_reso_steps
+        if min_bucket_reso % bucket_reso_steps != 0:
+            adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
+            logger.warning(
+                f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
+                f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
+            )
+            min_bucket_reso = adjusted_min_bucket_reso
+        if max_bucket_reso % bucket_reso_steps != 0:
+            adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
+            logger.warning(
+                f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
+                f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
+            )
+            max_bucket_reso = adjusted_max_bucket_reso
+
+        assert (
+            min(resolution) >= min_bucket_reso
+        ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
+        assert (
+            max(resolution) <= max_bucket_reso
+        ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
+
+        return min_bucket_reso, max_bucket_reso
+
     def set_seed(self, seed):
         self.seed = seed

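The new method snaps min_bucket_reso down and max_bucket_reso up to the nearest multiple of bucket_reso_steps, so the configured training resolution always stays inside the bucket range. The rounding arithmetic can be reproduced in isolation; a minimal sketch with hypothetical values, outside the dataset class:

    def snap_bucket_range(min_reso: int, max_reso: int, steps: int) -> tuple:
        # round min down and max up to multiples of `steps`
        if min_reso % steps != 0:
            min_reso -= min_reso % steps
        if max_reso % steps != 0:
            max_reso += steps - max_reso % steps
        return min_reso, max_reso

    # e.g. steps=64: 250 rounds down to 192, 1000 rounds up to 1024
    print(snap_bucket_range(250, 1000, 64))  # (192, 1024)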
@@ -988,9 +1021,26 @@ class BaseDataset(torch.utils.data.Dataset):
         # sort by resolution
         image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

-        # split by resolution
-        batches = []
-        batch = []
+        # split by resolution and some conditions
+        class Condition:
+            def __init__(self, reso, flip_aug, alpha_mask, random_crop):
+                self.reso = reso
+                self.flip_aug = flip_aug
+                self.alpha_mask = alpha_mask
+                self.random_crop = random_crop
+
+            def __eq__(self, other):
+                return (
+                    self.reso == other.reso
+                    and self.flip_aug == other.flip_aug
+                    and self.alpha_mask == other.alpha_mask
+                    and self.random_crop == other.random_crop
+                )
+
+        batches: List[Tuple[Condition, List[ImageInfo]]] = []
+        batch: List[ImageInfo] = []
+        current_condition = None
+
         logger.info("checking cache validity...")
         for info in tqdm(image_infos):
             subset = self.image_to_subset[info.image_key]
@@ -1011,20 +1061,23 @@ class BaseDataset(torch.utils.data.Dataset):
             if cache_available:  # do not add to batch
                 continue

-            # if last member of batch has different resolution, flush the batch
-            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
-                batches.append(batch)
+            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+            if len(batch) > 0 and current_condition != condition:
+                batches.append((current_condition, batch))
                 batch = []

             batch.append(info)
+            current_condition = condition

             # if number of data in batch is enough, flush the batch
             if len(batch) >= caching_strategy.batch_size:
-                batches.append(batch)
+                batches.append((current_condition, batch))
                 batch = []
+                current_condition = None

         if len(batch) > 0:
-            batches.append(batch)
+            batches.append((current_condition, batch))

         # if cache to disk, don't cache latents in non-main process, set to info only
         if caching_strategy.cache_to_disk and not is_main_process:
@@ -1036,9 +1089,8 @@ class BaseDataset(torch.utils.data.Dataset):

         # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
         logger.info("caching latents...")
-        for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
-            caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
+            caching_strategy.cache_batch_latents(model, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)

     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # multi-GPU is not supported here; use tools/cache_latents.py for that
@@ -1049,9 +1101,26 @@ class BaseDataset(torch.utils.data.Dataset):
         # sort by resolution
         image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

-        # split by resolution
-        batches = []
-        batch = []
+        # split by resolution and some conditions
+        class Condition:
+            def __init__(self, reso, flip_aug, alpha_mask, random_crop):
+                self.reso = reso
+                self.flip_aug = flip_aug
+                self.alpha_mask = alpha_mask
+                self.random_crop = random_crop
+
+            def __eq__(self, other):
+                return (
+                    self.reso == other.reso
+                    and self.flip_aug == other.flip_aug
+                    and self.alpha_mask == other.alpha_mask
+                    and self.random_crop == other.random_crop
+                )
+
+        batches: List[Tuple[Condition, List[ImageInfo]]] = []
+        batch: List[ImageInfo] = []
+        current_condition = None
+
         logger.info("checking cache validity...")
         for info in tqdm(image_infos):
             subset = self.image_to_subset[info.image_key]
@@ -1072,28 +1141,31 @@ class BaseDataset(torch.utils.data.Dataset):
             if cache_available:  # do not add to batch
                 continue

-            # if last member of batch has different resolution, flush the batch
-            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
-                batches.append(batch)
+            # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
+            condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+            if len(batch) > 0 and current_condition != condition:
+                batches.append((current_condition, batch))
                 batch = []

             batch.append(info)
+            current_condition = condition

             # if number of data in batch is enough, flush the batch
             if len(batch) >= vae_batch_size:
-                batches.append(batch)
+                batches.append((current_condition, batch))
                 batch = []
+                current_condition = None

         if len(batch) > 0:
-            batches.append(batch)
+            batches.append((current_condition, batch))

         if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
             return

         # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
         logger.info("caching latents...")
-        for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
+            cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)

     def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
         r"""
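Both caching paths now key each batch on (bucket_reso, flip_aug, alpha_mask, random_crop) instead of resolution alone, so a batch never mixes items that must be processed differently, and the condition travels with the batch to the caching call. The flush logic can be sketched in isolation; a minimal, illustrative version using plain dicts in place of ImageInfo:

    from typing import List, Tuple

    def group_batches(items: List[dict], batch_size: int) -> List[Tuple[tuple, List[dict]]]:
        # key = everything that must be uniform within one cached batch
        key = lambda it: (it["reso"], it["flip_aug"], it["alpha_mask"], it["random_crop"])
        batches, batch, current = [], [], None
        for it in items:
            if batch and current != key(it):  # condition changed: flush
                batches.append((current, batch))
                batch = []
            batch.append(it)
            current = key(it)
            if len(batch) >= batch_size:  # batch full: flush
                batches.append((current, batch))
                batch, current = [], None
        if batch:
            batches.append((current, batch))
        return batches

    items = [{"reso": (512, 512), "flip_aug": False, "alpha_mask": False, "random_crop": False} for _ in range(5)]
    print(len(group_batches(items, batch_size=2)))  # 3 batches: 2 + 2 + 1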
@@ -1663,12 +1735,9 @@ class DreamBoothDataset(BaseDataset):

         self.enable_bucket = enable_bucket
         if self.enable_bucket:
-            assert (
-                min(resolution) >= min_bucket_reso
-            ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
-            assert (
-                max(resolution) <= max_bucket_reso
-            ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
+            min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+            )
             self.min_bucket_reso = min_bucket_reso
             self.max_bucket_reso = max_bucket_reso
             self.bucket_reso_steps = bucket_reso_steps
@@ -1708,7 +1777,7 @@ class DreamBoothDataset(BaseDataset):
         def load_dreambooth_dir(subset: DreamBoothSubset):
             if not os.path.isdir(subset.image_dir):
                 logger.warning(f"not directory: {subset.image_dir}")
-                return [], []
+                return [], [], []

             info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
             use_cached_info_for_subset = subset.cache_info
@@ -2062,6 +2131,9 @@ class FineTuningDataset(BaseDataset):

         self.enable_bucket = enable_bucket
         if self.enable_bucket:
+            min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+            )
             self.min_bucket_reso = min_bucket_reso
             self.max_bucket_reso = max_bucket_reso
             self.bucket_reso_steps = bucket_reso_steps
@@ -2284,9 +2356,7 @@ class ControlNetDataset(BaseDataset):
             # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
             # resize to target
             if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
-                cond_img = cv2.resize(
-                    cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
-                )
+                cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))

             if flipped:
                 cond_img = cond_img[:, ::-1, :].copy()  # copy to avoid negative stride
@@ -2425,7 +2495,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
     if alpha_mask:
         if "alpha_mask" not in npz:
             return False
-        if npz["alpha_mask"].shape[0:2] != reso:  # HxW
+        if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso:  # HxW => WxH != reso
            return False
     else:
         if "alpha_mask" in npz:
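The fix matters because the cached numpy array is stored as (H, W) while reso is (W, H); comparing shape[0:2] to reso directly would reject every non-square cache. A quick illustration with hypothetical sizes:

    import numpy as np

    reso = (768, 512)  # (W, H)
    mask = np.zeros((512, 768))  # numpy shape is (H, W)

    print(mask.shape[0:2] == reso)                 # False: (512, 768) vs (768, 512)
    print((mask.shape[1], mask.shape[0]) == reso)  # True after swapping to (W, H)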
@@ -2534,7 +2604,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
             if "alpha_masks" in example and example["alpha_masks"] is not None:
                 alpha_mask = example["alpha_masks"][j]
                 logger.info(f"alpha mask size: {alpha_mask.size()}")
-                alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
+                alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)
                 if os.name == "nt":
                     cv2.imshow("alpha_mask", alpha_mask)

@@ -2680,7 +2750,10 @@ def trim_and_resize_if_required(

     if image_width != resized_size[0] or image_height != resized_size[1]:
         # resize the image
-        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
+        if image_width > resized_size[0] and image_height > resized_size[1]:
+            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
+        else:
+            image = pil_resize(image, resized_size)

     image_height, image_width = image.shape[0:2]

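The split reflects a quality tradeoff: cv2's INTER_AREA is well suited to shrinking but degrades when enlarging, which is presumably why upscaling is routed through the new pil_resize helper. A rough standalone equivalent, using PIL directly in place of library.utils.pil_resize (an assumption about what that helper does internally):

    import cv2
    import numpy as np
    from PIL import Image

    def resize_mixed(image: np.ndarray, size_wh: tuple) -> np.ndarray:
        h, w = image.shape[:2]
        if w > size_wh[0] and h > size_wh[1]:
            # strictly downscaling: INTER_AREA gives good antialiasing
            return cv2.resize(image, size_wh, interpolation=cv2.INTER_AREA)
        # upscaling (or mixed): fall back to PIL's resampling
        return np.array(Image.fromarray(image).resize(size_wh, Image.LANCZOS))

    img = np.zeros((512, 512, 3), dtype=np.uint8)
    print(resize_mixed(img, (256, 256)).shape)  # (256, 256, 3)
    print(resize_mixed(img, (768, 768)).shape)  # (768, 768, 3)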
@@ -3270,11 +3343,29 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


 def add_optimizer_arguments(parser: argparse.ArgumentParser):
+    def int_or_float(value):
+        if value.endswith("%"):
+            try:
+                return float(value[:-1]) / 100.0
+            except ValueError:
+                raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage")
+        try:
+            float_value = float(value)
+            if float_value >= 1:
+                return int(value)
+            return float(value)
+        except ValueError:
+            raise argparse.ArgumentTypeError(f"'{value}' is not an int or float")
+
     parser.add_argument(
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, "
+        "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
+        "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, "
+        "AdaFactor. "
+        "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.",
     )

     # backward compatibility
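int_or_float lets one flag accept an absolute step count, a ratio below 1, or a percentage; values of 1 or more come back as int, everything else as float. A sketch of how it behaves when wired into a parser (simplified error handling, hypothetical flag name):

    import argparse

    def int_or_float(value):  # same logic as the nested helper above, minus error wrapping
        if value.endswith("%"):
            return float(value[:-1]) / 100.0
        f = float(value)
        return int(value) if f >= 1 else f

    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup", type=int_or_float, default=0)  # hypothetical flag
    print(parser.parse_args(["--warmup", "500"]).warmup)  # 500 (absolute steps)
    print(parser.parse_args(["--warmup", "0.1"]).warmup)  # 0.1 (ratio of train steps)
    print(parser.parse_args(["--warmup", "10%"]).warmup)  # 0.1 (percentage form)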
@@ -3305,6 +3396,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )

+    # parser.add_argument(
+    #     "--optimizer_schedulefree_wrapper",
+    #     action="store_true",
+    #     help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
+    # )
+
+    # parser.add_argument(
+    #     "--schedulefree_wrapper_args",
+    #     type=str,
+    #     default=None,
+    #     nargs="*",
+    #     help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / オプティマイザの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
+    # )
+
     parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
     parser.add_argument(
         "--lr_scheduler_args",
@@ -3322,9 +3427,17 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
     )
     parser.add_argument(
         "--lr_warmup_steps",
-        type=int,
+        type=int_or_float,
         default=0,
-        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)",
+        help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps"
+        " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
     )
+    parser.add_argument(
+        "--lr_decay_steps",
+        type=int_or_float,
+        default=0,
+        help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps"
+        " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)",
+    )
     parser.add_argument(
         "--lr_scheduler_num_cycles",
@@ -3344,6 +3457,20 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
         + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
     )
+    parser.add_argument(
+        "--lr_scheduler_timescale",
+        type=int,
+        default=None,
+        help="Inverse sqrt timescale for inverse sqrt scheduler, defaults to `num_warmup_steps`"
+        + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`",
+    )
+    parser.add_argument(
+        "--lr_scheduler_min_lr_ratio",
+        type=float,
+        default=None,
+        help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler"
+        + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効",
+    )


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -4071,8 +4198,20 @@ def add_dataset_arguments(
         action="store_true",
         help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
     )
-    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
-    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
+    parser.add_argument(
+        "--min_bucket_reso",
+        type=int,
+        default=256,
+        help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
+        " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
+    )
+    parser.add_argument(
+        "--max_bucket_reso",
+        type=int,
+        default=1024,
+        help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
+        " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
+    )
     parser.add_argument(
         "--bucket_reso_steps",
         type=int,
@@ -4290,7 +4429,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):


 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
@@ -4343,6 +4482,7 @@ def get_optimizer(args, trainable_params):

     lr = args.learning_rate
     optimizer = None
+    optimizer_class = None

     if optimizer_type == "Lion".lower():
         try:
@@ -4400,7 +4540,8 @@ def get_optimizer(args, trainable_params):
                 "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
             )

-        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+        if optimizer_class is not None:
+            optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

     elif optimizer_type == "PagedAdamW".lower():
         logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}")
@@ -4562,26 +4703,159 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+        # set the optimizer to train mode: we don't need to call train() again, because eval() will not be called in the training loop
+        optimizer.train()
+
     if optimizer is None:
         # use any optimizer
-        optimizer_type = args.optimizer_type  # not the lowercased one (a bit awkward)
-        logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
-        if "." not in optimizer_type:
-            optimizer_module = torch.optim
-        else:
-            values = optimizer_type.split(".")
-            optimizer_module = importlib.import_module(".".join(values[:-1]))
-            optimizer_type = values[-1]
-        optimizer_class = getattr(optimizer_module, optimizer_type)
+        case_sensitive_optimizer_type = args.optimizer_type  # not lower
+        logger.info(f"use {case_sensitive_optimizer_type} | {optimizer_kwargs}")
+
+        if "." not in case_sensitive_optimizer_type:  # from torch.optim
+            optimizer_module = torch.optim
+        else:  # from other library
+            values = case_sensitive_optimizer_type.split(".")
+            optimizer_module = importlib.import_module(".".join(values[:-1]))
+            case_sensitive_optimizer_type = values[-1]
+
+        optimizer_class = getattr(optimizer_module, case_sensitive_optimizer_type)
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    """
+    # wrap any of above optimizer with schedulefree, if optimizer is not schedulefree
+    if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+
+        schedulefree_wrapper_kwargs = {}
+        if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
+            for arg in args.schedulefree_wrapper_args:
+                key, value = arg.split("=")
+                value = ast.literal_eval(value)
+                schedulefree_wrapper_kwargs[key] = value
+
+        sf_wrapper = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
+        sf_wrapper.train()  # set the optimizer to train mode
+
+        # we need the optimizer to be a subclass of torch.optim.Optimizer, so we make another proxy class over the SF wrapper
+        class OptimizerProxy(torch.optim.Optimizer):
+            def __init__(self, sf_wrapper):
+                self._sf_wrapper = sf_wrapper
+
+            def __getattr__(self, name):
+                return getattr(self._sf_wrapper, name)
+
+            # override properties
+            @property
+            def state(self):
+                return self._sf_wrapper.state
+
+            @state.setter
+            def state(self, state):
+                self._sf_wrapper.state = state
+
+            @property
+            def param_groups(self):
+                return self._sf_wrapper.param_groups
+
+            @param_groups.setter
+            def param_groups(self, param_groups):
+                self._sf_wrapper.param_groups = param_groups
+
+            @property
+            def defaults(self):
+                return self._sf_wrapper.defaults
+
+            @defaults.setter
+            def defaults(self, defaults):
+                self._sf_wrapper.defaults = defaults
+
+            def add_param_group(self, param_group):
+                self._sf_wrapper.add_param_group(param_group)
+
+            def load_state_dict(self, state_dict):
+                self._sf_wrapper.load_state_dict(state_dict)
+
+            def state_dict(self):
+                return self._sf_wrapper.state_dict()
+
+            def zero_grad(self):
+                self._sf_wrapper.zero_grad()
+
+            def step(self, closure=None):
+                self._sf_wrapper.step(closure)
+
+            def train(self):
+                self._sf_wrapper.train()
+
+            def eval(self):
+                self._sf_wrapper.eval()
+
+            # method to pass isinstance checks
+            def __instancecheck__(self, instance):
+                return isinstance(instance, (type(self), Optimizer))
+
+        optimizer = OptimizerProxy(sf_wrapper)
+
+        logger.info(f"wrap optimizer with ScheduleFreeWrapper | {schedulefree_wrapper_kwargs}")
+    """
+
     # for logging
     optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
     optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])

     return optimizer_name, optimizer_args, optimizer


+def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
+    if not is_schedulefree_optimizer(optimizer, args):
+        # return dummy func
+        return lambda: None, lambda: None
+
+    # get train and eval functions from optimizer
+    train_fn = optimizer.train
+    eval_fn = optimizer.eval
+
+    return train_fn, eval_fn
+
+
+def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
+    return args.optimizer_type.lower().endswith("schedulefree".lower())  # or args.optimizer_schedulefree_wrapper
+
+
+def get_dummy_scheduler(optimizer: Optimizer) -> Any:
+    # dummy scheduler for schedulefree optimizer; supports only empty step() and get_last_lr()
+    # this scheduler is used for logging only
+    # it isn't wrapped by accelerator because this class is not a subclass of torch.optim.lr_scheduler._LRScheduler
+    class DummyScheduler:
+        def __init__(self, optimizer: Optimizer):
+            self.optimizer = optimizer
+
+        def step(self):
+            pass
+
+        def get_last_lr(self):
+            return [group["lr"] for group in self.optimizer.param_groups]
+
+    return DummyScheduler(optimizer)
+
+
 # Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
 # Add some checking and features to the original function.

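The rewritten fallback resolves a dotted class path at runtime, so any optimizer importable on the system can be selected by name; the case-sensitive copy is kept because module and class names must not be lowercased. The mechanism can be sketched on its own, using torch.optim targets so it runs anywhere:

    import importlib
    import torch

    def load_optimizer_class(name: str):
        # bare name -> torch.optim; dotted path -> import that module
        if "." not in name:
            return getattr(torch.optim, name)
        module_path, cls_name = name.rsplit(".", 1)
        return getattr(importlib.import_module(module_path), cls_name)

    params = [torch.nn.Parameter(torch.zeros(2))]
    opt = load_optimizer_class("AdamW")(params, lr=1e-4)             # from torch.optim
    opt2 = load_optimizer_class("torch.optim.SGD")(params, lr=1e-2)  # full dotted path
    print(type(opt).__name__, type(opt2).__name__)  # AdamW SGD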
@@ -4590,11 +4864,23 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
     """
+    # if schedulefree optimizer, return dummy scheduler
+    if is_schedulefree_optimizer(optimizer, args):
+        return get_dummy_scheduler(optimizer)
+
     name = args.lr_scheduler
-    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
+    num_warmup_steps: Optional[int] = (
+        int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps
+    )
+    num_decay_steps: Optional[int] = (
+        int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps
+    )
+    num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
+    timescale = args.lr_scheduler_timescale
+    min_lr_ratio = args.lr_scheduler_min_lr_ratio

     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
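Because int_or_float runs upstream, a float here is a fraction of total training steps while an int is used verbatim; note the total is multiplied by num_processes first. A worked example of the conversion with hypothetical numbers:

    max_train_steps, num_processes = 10_000, 2
    num_training_steps = max_train_steps * num_processes  # 20_000

    for lr_warmup_steps in (500, 0.05):  # absolute count vs ratio
        num_warmup_steps = (
            int(lr_warmup_steps * num_training_steps) if isinstance(lr_warmup_steps, float) else lr_warmup_steps
        )
        print(num_warmup_steps)  # 500, then 1000 (= 5% of 20_000)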
@@ -4630,15 +4916,17 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         # logger.info(f"adafactor scheduler init lr {initial_lr}")
         return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

+    if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value:
+        name = DiffusersSchedulerType(name)
+        schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name]
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
+
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

     if name == SchedulerType.CONSTANT:
         return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))

-    if name == SchedulerType.PIECEWISE_CONSTANT:
-        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
-
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
@@ -4646,6 +4934,9 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)

+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs)
+
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
@@ -4664,7 +4955,46 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
         )

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
+    if name == SchedulerType.COSINE_WITH_MIN_LR:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_rate=min_lr_ratio,
+            **lr_scheduler_kwargs,
+        )
+
+    # these schedulers do not require `num_decay_steps`
+    if name == SchedulerType.LINEAR or name == SchedulerType.COSINE:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            **lr_scheduler_kwargs,
+        )
+
+    # All other schedulers require `num_decay_steps`
+    if num_decay_steps is None:
+        raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.")
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_stable_steps=num_stable_steps,
+            num_decay_steps=num_decay_steps,
+            num_cycles=num_cycles / 2,
+            min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0,
+            **lr_scheduler_kwargs,
+        )
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_decay_steps=num_decay_steps,
+        **lr_scheduler_kwargs,
+    )


 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
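For the WARMUP_STABLE_DECAY branch, the stable phase is simply whatever remains after warmup and decay are carved out of the total (num_stable_steps above). A small sketch of the resulting piecewise shape, with illustrative linear ramps rather than the exact curves diffusers implements:

    def wsd_lr(step: int, warmup: int, stable: int, decay: int, min_ratio: float = 0.0) -> float:
        # returns the LR multiplier at `step` for a warmup/stable/decay schedule
        if step < warmup:
            return step / max(1, warmup)               # linear warmup 0 -> 1
        if step < warmup + stable:
            return 1.0                                  # hold
        done = min(1.0, (step - warmup - stable) / max(1, decay))
        return 1.0 - (1.0 - min_ratio) * done           # linear decay 1 -> min_ratio

    total, warmup, decay = 1000, 100, 200
    stable = total - warmup - decay                     # 700, as computed above
    print([round(wsd_lr(s, warmup, stable, decay), 2) for s in (0, 50, 100, 500, 900, 1000)])
    # [0.0, 0.5, 1.0, 1.0, 0.5, 0.0]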
@@ -5312,34 +5642,27 @@ def save_sd_model_on_train_end_common(


 def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
-
-    # TODO: if a huber loss is selected, it will use constant timesteps for each batch
-    # as. In the future there may be a smarter way
+    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

     if args.loss_type == "huber" or args.loss_type == "smooth_l1":
-        timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
-        timestep = timesteps.item()
-
         if args.huber_schedule == "exponential":
             alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
-            huber_c = math.exp(-alpha * timestep)
+            huber_c = torch.exp(-alpha * timesteps)
         elif args.huber_schedule == "snr":
-            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+            alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
             sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
             huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
         elif args.huber_schedule == "constant":
-            huber_c = args.huber_c
+            huber_c = torch.full((b_size,), args.huber_c)
         else:
             raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
-
-        timesteps = timesteps.repeat(b_size).to(device)
+        huber_c = huber_c.to(device)
     elif args.loss_type == "l2":
-        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
-        huber_c = 1  # may be anything, as it's not used
+        huber_c = None  # may be anything, as it's not used
     else:
         raise NotImplementedError(f"Unknown loss type {args.loss_type}")
-    timesteps = timesteps.long()

+    timesteps = timesteps.long().to(device)
     return timesteps, huber_c
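With this change, all three schedules produce a per-sample huber_c tensor aligned with per-sample timesteps, instead of one scalar shared by the whole batch. The shapes and the exponential curve can be checked in isolation; a sketch with made-up hyperparameters and a fake linear alphas_cumprod ramp standing in for the scheduler's table:

    import math
    import torch

    b_size, num_train_timesteps, huber_c0 = 4, 1000, 0.1
    timesteps = torch.randint(0, num_train_timesteps, (b_size,))

    # exponential: huber_c decays from 1 at t=0 down to huber_c0 at t=T
    alpha = -math.log(huber_c0) / num_train_timesteps
    huber_exp = torch.exp(-alpha * timesteps)

    # snr: fake alphas_cumprod just to show the per-sample indexing pattern
    alphas_cumprod = torch.linspace(0.999, 0.001, num_train_timesteps)
    ac = torch.index_select(alphas_cumprod, 0, timesteps)
    sigmas = ((1.0 - ac) / ac) ** 0.5
    huber_snr = (1 - huber_c0) / (1 + sigmas) ** 2 + huber_c0

    # constant: the same value repeated per sample
    huber_const = torch.full((b_size,), huber_c0)

    print(huber_exp.shape, huber_snr.shape, huber_const.shape)  # all torch.Size([4])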
@@ -5378,21 +5701,22 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     return noise, noisy_latents, timesteps, huber_c


+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
 def conditional_loss(
-    model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
+    model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
 ):
     if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
+    elif loss_type == "l1":
+        loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
     elif loss_type == "huber":
+        huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
     elif loss_type == "smooth_l1":
+        huber_c = huber_c.view(-1, 1, 1, 1)
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
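The pseudo-Huber form 2c(sqrt(x^2 + c^2) - c) behaves like x^2 for |x| much smaller than c and like 2c|x| for large residuals, which is what makes it robust to outliers; with huber_c now a per-sample tensor, the view(-1, 1, 1, 1) broadcasts it over the CHW dimensions. A quick numeric check of the two asymptotes, using a scalar c for readability:

    import torch

    x = torch.tensor([0.01, 0.1, 1.0, 10.0])
    c = 0.5

    huber = 2 * c * (torch.sqrt(x**2 + c**2) - c)
    print(huber)                 # ~x**2 for small x, ~2*c*|x| for large x
    print(x**2)                  # quadratic reference for small residuals
    print(2 * c * x - 2 * c**2)  # linear asymptote for large residuals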
@@ -5678,7 +6002,7 @@ def sample_images_common(
     clean_memory_on_device(accelerator.device)

     torch.set_rng_state(rng_state)
-    if cuda_rng_state is not None:
+    if torch.cuda.is_available() and cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)

@@ -5712,11 +6036,13 @@ def sample_image_inference(

     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
     else:
         # True random sample image generation
         torch.seed()
-        torch.cuda.seed()
+        if torch.cuda.is_available():
+            torch.cuda.seed()

     scheduler = get_my_scheduler(
         sample_sampler=sampler_name,
@@ -5751,8 +6077,9 @@ def sample_image_inference(
         controlnet_image=controlnet_image,
     )

-    with torch.cuda.device(torch.cuda.current_device()):
-        torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        with torch.cuda.device(torch.cuda.current_device()):
+            torch.cuda.empty_cache()

     image = pipeline.latents_to_image(latents)[0]

@@ -5766,17 +6093,14 @@ def sample_image_inference(
     img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
     image.save(os.path.join(save_dir, img_filename))

-    # send logs only when wandb is enabled
-    try:
+    # send images to wandb if enabled
+    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
         wandb_tracker = accelerator.get_tracker("wandb")
-        try:
-            import wandb
-        except ImportError:  # checked once beforehand, so this should not raise
-            raise ImportError("No wandb / wandb がインストールされていないようです")

-        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-    except:  # when wandb is disabled
-        pass
+        import wandb
+
+        # not to commit images to avoid inconsistency between training and logging steps
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption


 # endregion