Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-18 01:30:02 +00:00)

Merge pull request #750 from kohya-ss/dev: block lr for U-Net with SDXL etc.
@@ -52,6 +52,10 @@ def main(args):
     # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
     if args.bucket_reso_steps % 8 > 0:
         print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
+    if args.bucket_reso_steps % 32 > 0:
+        print(
+            f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
+        )
 
     train_data_dir_path = Path(args.train_data_dir)
     image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
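Note: the SDXL VAE reduces resolution 8x and the SDXL U-Net downsamples the latent a further 4x, so image resolutions need 8 * 2 * 2 = 32-pixel alignment; the new warning catches step values that only satisfy the old multiple-of-8 rule. A minimal sketch of the constraint (the helper name is hypothetical, not part of the repo):

def sdxl_compatible_reso_step(reso_step: int) -> bool:
    # buckets are laid out on a grid of reso_step pixels; SDXL needs 32-pixel alignment
    return reso_step % 32 == 0

assert sdxl_compatible_reso_step(64) and not sdxl_compatible_reso_step(8)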
@@ -13,7 +13,7 @@ from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeigh
 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
 TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 
-DEFAULT_NOISE_OFFSET = 0.0357
+# DEFAULT_NOISE_OFFSET = 0.0357
 
 
 def load_target_model(args, accelerator, model_version: str, weight_dtype):
@@ -312,18 +312,18 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
     if args.clip_skip is not None:
         print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
 
-    if args.multires_noise_iterations:
-        print(
-            f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
-        )
-    else:
-        if args.noise_offset is None:
-            args.noise_offset = DEFAULT_NOISE_OFFSET
-        elif args.noise_offset != DEFAULT_NOISE_OFFSET:
-            print(
-                f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
-            )
-        print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
+    # if args.multires_noise_iterations:
+    #     print(
+    #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
+    #     )
+    # else:
+    #     if args.noise_offset is None:
+    #         args.noise_offset = DEFAULT_NOISE_OFFSET
+    #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
+    #         print(
+    #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
+    #         )
+    #     print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
 
     assert (
         not hasattr(args, "weighted_captions") or not args.weighted_captions
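Note: with this block commented out, SDXL training no longer defaults noise_offset to 0.0357 and no longer warns when multires noise would have overridden it; the value must now be passed explicitly. A hedged sketch of the old defaulting behavior, for anyone who wants to restore it in their own launcher (the wrapper function is illustrative, not part of the repo):

DEFAULT_NOISE_OFFSET = 0.0357  # the value SDXL was reportedly trained with

def apply_legacy_noise_offset_default(args):
    # pre-#750 behavior: fall back to the SDXL default unless multires noise is enabled
    if not args.multires_noise_iterations and args.noise_offset is None:
        args.noise_offset = DEFAULT_NOISE_OFFSET
    return args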
@@ -800,6 +800,12 @@ class BaseDataset(torch.utils.data.Dataset):
             random.shuffle(self.buckets_indices)
         self.bucket_manager.shuffle()
 
+    def verify_bucket_reso_steps(self, min_steps: int):
+        assert self.bucket_reso_steps is None or self.bucket_reso_steps % min_steps == 0, (
+            f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
+            + f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
+        )
+
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
@@ -1831,6 +1837,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for dataset in self.datasets:
             dataset.set_caching_mode(caching_mode)
 
+    def verify_bucket_reso_steps(self, min_steps: int):
+        for dataset in self.datasets:
+            dataset.verify_bucket_reso_steps(min_steps)
+
     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2020,6 +2030,9 @@ class MinimalDataset(BaseDataset):
         self.is_reg = False
         self.image_dir = "dummy"  # for metadata
 
+    def verify_bucket_reso_steps(self, min_steps: int):
+        pass
+
     def is_latent_cacheable(self) -> bool:
         return False
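Note: these three overrides form a small polymorphic check: BaseDataset enforces the divisibility rule, DatasetGroup fans the call out to every child dataset, and MinimalDataset (which manages no buckets of its own) opts out with a no-op. Callers only ever need the single line the SDXL scripts add later in this diff:

# e.g. immediately after building the dataset group in an SDXL script:
train_dataset_group.verify_bucket_reso_steps(32)  # AssertionError if any dataset uses an incompatible step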
@@ -2981,11 +2994,11 @@ def verify_training_args(args: argparse.Namespace):
         )
 
     # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
-    # Listを使って数えてもいいけど並べてしまえ
-    if args.noise_offset is not None and args.multires_noise_iterations is not None:
-        raise ValueError(
-            "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
-        )
+    # # Listを使って数えてもいいけど並べてしまえ
+    # if args.noise_offset is not None and args.multires_noise_iterations is not None:
+    #     raise ValueError(
+    #         "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
+    #     )
     # if args.noise_offset is not None and args.perlin_noise is not None:
     #     raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
     # if args.perlin_noise is not None and args.multires_noise_iterations is not None:
@@ -4268,7 +4281,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
         noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
-    elif args.multires_noise_iterations:
+    if args.multires_noise_iterations:
        noise = custom_train_functions.pyramid_noise_like(
            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
        )
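Note: changing the elif to a plain if means the two noise tricks now compose instead of being mutually exclusive: the channel-wise offset is applied first and pyramid noise is layered onto the offset result. A self-contained sketch of that composition, with both helpers inlined in simplified form (the real apply_noise_offset also supports adaptive_noise_scale, and pyramid_noise_like randomizes the downscale ratio):

import torch

def combined_noise(latents, noise_offset=0.0357, iterations=6, discount=0.3):
    noise = torch.randn_like(latents)
    if noise_offset:  # simplified apply_noise_offset: one offset per (sample, channel)
        noise = noise + noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )
    if iterations:  # simplified pyramid_noise_like: coarse noise upsampled and discounted
        b, c, h, w = noise.shape
        upsample = torch.nn.Upsample(size=(h, w), mode="bilinear")
        for i in range(1, iterations + 1):
            if h // 2**i < 1 or w // 2**i < 1:
                break
            noise = noise + discount**i * upsample(
                torch.randn(b, c, h // 2**i, w // 2**i, device=latents.device)
            )
        noise = noise / noise.std()  # renormalize after stacking scales
    return noise

noise = combined_noise(torch.randn(2, 4, 64, 64))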
@@ -1309,7 +1309,10 @@ def main(args):
 
     # schedulerを用意する
     sched_init_args = {}
+    has_steps_offset = True
+    has_clip_sample = True
     scheduler_num_noises_per_step = 1
 
     if args.sampler == "ddim":
         scheduler_cls = DDIMScheduler
         scheduler_module = diffusers.schedulers.scheduling_ddim
@@ -1319,32 +1322,48 @@ def main(args):
     elif args.sampler == "pndm":
         scheduler_cls = PNDMScheduler
         scheduler_module = diffusers.schedulers.scheduling_pndm
+        has_clip_sample = False
     elif args.sampler == "lms" or args.sampler == "k_lms":
         scheduler_cls = LMSDiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_lms_discrete
+        has_clip_sample = False
     elif args.sampler == "euler" or args.sampler == "k_euler":
         scheduler_cls = EulerDiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_euler_discrete
+        has_clip_sample = False
     elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
         scheduler_cls = EulerAncestralDiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
+        has_clip_sample = False
     elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
         scheduler_cls = DPMSolverMultistepScheduler
         sched_init_args["algorithm_type"] = args.sampler
         scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep
+        has_clip_sample = False
     elif args.sampler == "dpmsingle":
         scheduler_cls = DPMSolverSinglestepScheduler
         scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep
+        has_clip_sample = False
+        has_steps_offset = False
     elif args.sampler == "heun":
         scheduler_cls = HeunDiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_heun_discrete
+        has_clip_sample = False
     elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2":
         scheduler_cls = KDPM2DiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete
+        has_clip_sample = False
     elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a":
         scheduler_cls = KDPM2AncestralDiscreteScheduler
         scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete
         scheduler_num_noises_per_step = 2
+        has_clip_sample = False
 
+    # 警告を出さないようにする
+    if has_steps_offset:
+        sched_init_args["steps_offset"] = 1
+    if has_clip_sample:
+        sched_init_args["clip_sample"] = False
 
     # samplerの乱数をあらかじめ指定するための処理
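Note: the two flags collected above feed the new block at the end of the branch chain; its comment 警告を出さないようにする means "avoid emitting warnings", i.e. steps_offset and clip_sample are only passed to scheduler classes whose constructors actually accept them, so diffusers does not warn about unknown kwargs. A minimal sketch of the resulting construction for the euler sampler (assuming a reasonably recent diffusers; the beta values mirror Stable Diffusion's usual config):

from diffusers import EulerDiscreteScheduler

sched_init_args = {}
has_steps_offset, has_clip_sample = True, False  # per the table above for "euler"
if has_steps_offset:
    sched_init_args["steps_offset"] = 1
if has_clip_sample:
    sched_init_args["clip_sample"] = False

scheduler = EulerDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    **sched_init_args,
)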
@@ -1397,10 +1416,11 @@ def main(args):
         **sched_init_args,
     )
 
-    # clip_sample=Trueにする
-    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-        print("set clip_sample to True")
-        scheduler.config.clip_sample = True
+    # ↓以下は結局PipeでFalseに設定されるので意味がなかった
+    # # clip_sample=Trueにする
+    # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
+    #     print("set clip_sample to True")
+    #     scheduler.config.clip_sample = True
 
     # deviceを決定する
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps"を考量してない
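Note, for readers without Japanese: the new comment ↓以下は結局PipeでFalseに設定されるので意味がなかった says "the code below ended up being set back to False by the Pipe anyway, so it had no effect", and # clip_sample=Trueにする means "set clip_sample to True". In other words the clip_sample override was dead code, which is why it is commented out here rather than fixed.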
sdxl_train.py (127 lines changed)
@@ -5,6 +5,7 @@ import gc
 import math
 import os
 from multiprocessing import Value
+from typing import List
 import toml
 
 from tqdm import tqdm
@@ -30,6 +31,67 @@ from library.custom_train_functions import (
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
 
+UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
+
+
+def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
+    block_params = [[] for _ in range(len(block_lrs))]
+
+    for i, (name, param) in enumerate(unet.named_parameters()):
+        if name.startswith("time_embed.") or name.startswith("label_emb."):
+            block_index = 0  # 0
+        elif name.startswith("input_blocks."):  # 1-9
+            block_index = 1 + int(name.split(".")[1])
+        elif name.startswith("middle_block."):  # 10-12
+            block_index = 10 + int(name.split(".")[1])
+        elif name.startswith("output_blocks."):  # 13-21
+            block_index = 13 + int(name.split(".")[1])
+        elif name.startswith("out."):  # 22
+            block_index = 22
+        else:
+            raise ValueError(f"unexpected parameter name: {name}")
+
+        block_params[block_index].append(param)
+
+    params_to_optimize = []
+    for i, params in enumerate(block_params):
+        if block_lrs[i] == 0:  # 0のときは学習しない do not optimize when lr is 0
+            continue
+        params_to_optimize.append({"params": params, "lr": block_lrs[i]})
+
+    return params_to_optimize
+
+
+def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
+    lrs = lr_scheduler.get_last_lr()
+
+    lr_index = 0
+    block_index = 0
+    while lr_index < len(lrs):
+        if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = f"block{block_index}"
+            if block_lrs[block_index] == 0:
+                block_index += 1
+                continue
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            name = "text_encoder1"
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
+            name = "text_encoder2"
+        else:
+            raise ValueError(f"unexpected block_index: {block_index}")
+
+        block_index += 1
+
+        logs["lr/" + name] = float(lrs[lr_index])
+
+        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
+            logs["lr/d*lr/" + name] = (
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
+            )
+
+        lr_index += 1
+
+
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
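Note: the 23 indices partition the SDXL U-Net exactly as the comments indicate: block 0 covers the time and label embeddings, blocks 1-9 the nine input_blocks, 10-12 the middle_block, 13-21 the nine output_blocks, and block 22 the final out layer; a rate of 0 drops a block from the optimizer entirely. A hedged usage sketch with a toy stand-in for the U-Net (the real argument is an SdxlUNet2DConditionModel; ToyUNet merely mimics the parameter-name prefixes the function dispatches on, and the import assumes the repo root is on sys.path):

import torch
import torch.nn as nn
from sdxl_train import get_block_params_to_optimize  # the helper added above

class ToyUNet(nn.Module):
    # parameter names look like "time_embed.weight", "input_blocks.0.weight", ...
    def __init__(self):
        super().__init__()
        self.time_embed = nn.Linear(4, 4)
        self.input_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(9)])
        self.middle_block = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
        self.output_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(9)])
        self.out = nn.Linear(4, 4)

block_lrs = [0.0] * 23
for i in range(13, 22):  # e.g. train only the nine output blocks
    block_lrs[i] = 1e-4

groups = get_block_params_to_optimize(ToyUNet(), block_lrs)
optimizer = torch.optim.AdamW(groups)  # one param group per non-zero block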
@@ -40,6 +102,14 @@ def train(args):
         not args.train_text_encoder or not args.cache_text_encoder_outputs
     ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
 
+    if args.block_lr:
+        block_lrs = [float(lr) for lr in args.block_lr.split(",")]
+        assert (
+            len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
+        ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
+    else:
+        block_lrs = None
+
     cache_latents = args.cache_latents
     use_dreambooth_method = args.in_json is None
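Note: args.block_lr arrives as a single comma-separated string, so a valid value has exactly 23 numeric entries in block order (embeddings, 9 input, 3 middle, 9 output, out), passed on the command line as --block_lr "0,1e-4,...,0". An illustrative way to build and sanity-check such a value (the rates themselves are made up):

block_lr_arg = "0," + "1e-4," * 9 + "2e-4," * 3 + "1e-4," * 9 + "0"
block_lrs = [float(lr) for lr in block_lr_arg.split(",")]
assert len(block_lrs) == 23  # UNET_NUM_BLOCKS_FOR_BLOCK_LR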
@@ -98,6 +168,8 @@ def train(args):
     ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
     collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
 
+    train_dataset_group.verify_bucket_reso_steps(32)
+
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group, True)
         return
@@ -233,15 +305,28 @@ def train(args):
 
     for m in training_models:
         m.requires_grad_(True)
-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
 
-    # calculate number of trainable parameters
-    n_params = 0
-    for p in params:
-        n_params += p.numel()
+    if block_lrs is None:
+        params = []
+        for m in training_models:
+            params.extend(m.parameters())
+        params_to_optimize = params
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for p in params:
+            n_params += p.numel()
+    else:
+        params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
+        for m in training_models[1:]:  # Text Encoders if exists
+            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
+
+        # calculate number of trainable parameters
+        n_params = 0
+        for params in params_to_optimize:
+            for p in params["params"]:
+                n_params += p.numel()
 
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
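Note: in the block-lr branch, training_models[0] is always the U-Net and any remaining models are the text encoders, which keep the global args.learning_rate; blocks whose rate is 0 never reach the optimizer, so the group list can be shorter than 23. The nested parameter count is equivalent to this one-liner (sketch):

n_params = sum(p.numel() for group in params_to_optimize for p in group["params"])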
@@ -526,13 +611,18 @@ def train(args):
 
                 current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
                 if args.logging_dir is not None:
-                    logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                    if (
-                        args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
-                    ):  # tracking d*lr value
-                        logs["lr/d*lr"] = (
-                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                        )
+                    logs = {"loss": current_loss}
+                    if block_lrs is None:
+                        logs["lr"] = float(lr_scheduler.get_last_lr()[0])
+                        if (
+                            args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
+                        ):  # tracking d*lr value
+                            logs["lr/d*lr"] = (
+                                lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
+                            )
+                    else:
+                        append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
+
                     accelerator.log(logs, step=global_step)
 
                 # TODO moving averageにする
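Note: with per-block rates active, the single "lr" scalar is replaced by one key per optimized group, in the order the groups were created (U-Net blocks first, then text encoders); zero-rate blocks are skipped, matching append_block_lr_to_logs above. Illustrative shape of what lands in the tracker (values made up):

logs = {
    "loss": 0.131,
    "lr/block13": 1e-4,  # first non-zero U-Net block group
    "lr/block14": 1e-4,
    # ... one entry per remaining non-zero block ...
    "lr/text_encoder1": 4e-5,
    "lr/text_encoder2": 4e-5,
}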
@@ -636,6 +726,13 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--block_lr",
+        type=str,
+        default=None,
+        help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+        + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
+    )
 
     return parser
@@ -23,6 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             args.network_train_unet_only or not args.cache_text_encoder_outputs
         ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
 
+        train_dataset_group.verify_bucket_reso_steps(32)
+
     def load_target_model(self, args, weight_dtype, accelerator):
         (
             load_stable_diffusion_format,
@@ -19,6 +19,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
         super().assert_extra_args(args, train_dataset_group)
         sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
 
+        train_dataset_group.verify_bucket_reso_steps(32)
+
     def load_target_model(self, args, weight_dtype, accelerator):
         (
             load_stable_diffusion_format,