From 447c56bf505c2a84d00e88ac173a1b6961894429 Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Thu, 23 Mar 2023 09:53:14 +0900
Subject: [PATCH] Fix typo, change step to global_step, fix bugs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fine_tune.py               | 2 +-
 library/config_util.py     | 4 ++--
 library/train_util.py      | 2 +-
 train_db.py                | 2 +-
 train_network.py           | 2 +-
 train_textual_inversion.py | 2 +-
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 473a13ec..def942fa 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -265,8 +265,8 @@ def train(args):
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
+            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
-                train_dataset_group.set_current_step(step + 1)
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
diff --git a/library/config_util.py b/library/config_util.py
index 98d89b7e..84bbf308 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -57,7 +57,7 @@ class BaseSubsetParams:
     caption_dropout_every_n_epochs: int = 0
     caption_tag_dropout_rate: float = 0.0
     token_warmup_min: int = 1
-    token_warmup_step: Union[float,int] = 0
+    token_warmup_step: float = 0
 
 @dataclass
 class DreamBoothSubsetParams(BaseSubsetParams):
@@ -140,7 +140,7 @@ class ConfigSanitizer:
         "shuffle_caption": bool,
         "keep_tokens": int,
         "token_warmup_min": int,
-        "token_warmup_step": Union[float,int],
+        "token_warmup_step": Any(float,int),
     }
     # DO means DropOut
     DO_SUBSET_ASCENDABLE_SCHEMA = {
diff --git a/library/train_util.py b/library/train_util.py
index 52b51314..83e9372b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2046,7 +2046,7 @@ def add_dataset_arguments(
     )
 
     parser.add_argument(
-        "--token_warmup_steps",
+        "--token_warmup_step",
         type=float,
         default=0,
         help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
diff --git a/train_db.py b/train_db.py
index 164e354e..e17a8b79 100644
--- a/train_db.py
+++ b/train_db.py
@@ -241,6 +241,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            train_dataset_group.set_current_step(global_step)
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
                 print(f"stop text encoder training at step {global_step}")
@@ -249,7 +250,6 @@
                 text_encoder.requires_grad_(False)
 
             with accelerator.accumulate(unet):
-                train_dataset_group.set_current_step(step + 1)
                 with torch.no_grad():
                     # latentに変換
                     if cache_latents:
diff --git a/train_network.py b/train_network.py
index 16f41ebb..6d23ab07 100644
--- a/train_network.py
+++ b/train_network.py
@@ -507,8 +507,8 @@ def train(args):
             network.on_epoch_start(text_encoder, unet)
 
         for step, batch in enumerate(train_dataloader):
+            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(network):
-                train_dataset_group.set_current_step(step + 1)
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index b3467d94..42746169 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -340,8 +340,8 @@ def train(args):
 
         loss_total = 0
         for step, batch in enumerate(train_dataloader):
+            train_dataset_group.set_current_step(global_step)
             with accelerator.accumulate(text_encoder):
-                train_dataset_group.set_current_step(step + 1)
                 with torch.no_grad():
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device)
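
Note on the intent of the change: set_current_step() now receives global_step, which keeps counting optimizer updates across the whole run, instead of the per-epoch dataloader index step + 1, which resets to 0 at every epoch, so a token warmup longer than one epoch no longer restarts at each epoch boundary. The sketch below is a minimal illustration of that behaviour together with the --token_warmup_step semantics given in the help text ("N steps, or N*max_train_steps if N<1"); the helpers resolve_token_warmup_step and warmed_up_tag_length are hypothetical names for this illustration and are not the sd-scripts implementation.

# Minimal illustration only: resolve_token_warmup_step and warmed_up_tag_length
# are hypothetical helpers that mirror the --token_warmup_step help text
# (N steps, or N*max_train_steps when N < 1); they are not sd-scripts code.

def resolve_token_warmup_step(token_warmup_step: float, max_train_steps: int) -> int:
    # A value below 1 is read as a fraction of the total number of training steps.
    if token_warmup_step < 1:
        return int(token_warmup_step * max_train_steps)
    return int(token_warmup_step)

def warmed_up_tag_length(current_step: int, warmup_steps: int, tag_min: int, tag_max: int) -> int:
    # Assumed linear ramp of the usable tag length from tag_min to tag_max.
    if warmup_steps <= 0 or current_step >= warmup_steps:
        return tag_max
    return tag_min + int((tag_max - tag_min) * current_step / warmup_steps)

max_train_steps = 1600
warmup_steps = resolve_token_warmup_step(0.25, max_train_steps)  # 0.25 -> 400 steps

global_step = 0
for epoch in range(2):
    for step, _batch in enumerate(range(800)):  # stand-in for train_dataloader
        # `step` resets to 0 every epoch; `global_step` keeps counting, so the
        # warmup schedule below does not restart at the epoch boundary.
        tag_length = warmed_up_tag_length(global_step, warmup_steps, tag_min=1, tag_max=75)
        global_step += 1

With the old call, set_current_step(step + 1) could never report more than one epoch's worth of steps, so a warmup spanning several epochs would never complete; the default of 0 still means the tag length is at its maximum from the first step.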