mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
7 Commits
dc59e84aea
...
7c4f5f78fd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c4f5f78fd | ||
|
|
63711390a0 | ||
|
|
206adb6438 | ||
|
|
60bfa97b19 | ||
|
|
51e1b45abd | ||
|
|
a2c0f3644b | ||
|
|
96f06d917e |
@@ -327,15 +327,18 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
)
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
|
||||
if support_text_encoder_caching:
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs / text encoderの出力をキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_mmap_load_safetensors",
|
||||
action="store_true",
|
||||
@@ -343,7 +346,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
|
||||
if args.clip_skip is not None:
|
||||
@@ -366,7 +369,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
if supportTextEncoderCaching:
|
||||
if support_text_encoder_caching:
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
logger.warning(
|
||||
|
||||
@@ -3211,6 +3211,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_every_n_steps_after_x", type=int, default=0, help="save checkpoint every N steps only after X steps / N ステップごとにチェックポイントを保存しますが、X ステップ後にのみ保存します"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_n_epoch_ratio",
|
||||
type=int,
|
||||
|
||||
@@ -775,7 +775,7 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
|
||||
@@ -500,7 +500,7 @@ def train(args):
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -460,7 +460,7 @@ def train(args):
|
||||
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -5,6 +5,7 @@ import regex
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
@@ -19,8 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
# super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
@@ -122,8 +123,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_textual_inversion.setup_parser()
|
||||
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
||||
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -521,7 +521,7 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -399,7 +399,7 @@ def train(args):
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
|
||||
@@ -1037,7 +1037,8 @@ class NetworkTrainer:
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
|
||||
@@ -646,7 +646,7 @@ class TextualInversionTrainer:
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs_list = []
|
||||
|
||||
@@ -515,7 +515,7 @@ def train(args):
|
||||
# )
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0 and global_step >= args.save_every_n_steps_after_x:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = (
|
||||
|
||||
Reference in New Issue
Block a user