diff --git a/library/model_util.py b/library/model_util.py index be410a02..9918c7b2 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # rename or add position_ids + # remove position_ids for newer transformer, which causes error :( ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" if ANOTHER_POSITION_IDS_KEY in new_sd: # waifu diffusion v1.4 - position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] del new_sd[ANOTHER_POSITION_IDS_KEY] - else: - position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - new_sd["text_model.embeddings.position_ids"] = position_ids + if "text_model.embeddings.position_ids" in new_sd: + del new_sd["text_model.embeddings.position_ids"] + return new_sd diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index dc3887c3..7c5e6860 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -344,8 +344,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" - if args.v_parameterization: - logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります") if args.clip_skip is not None: logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません") diff --git a/train_network.py b/train_network.py index ab5483de..2d279b3b 100644 --- a/train_network.py +++ b/train_network.py @@ -1207,10 +1207,6 @@ class NetworkTrainer: args.max_train_steps > initial_step ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" - progress_bar = tqdm( - range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" - ) - epoch_to_start = 0 if initial_step > 0: if args.skip_until_initial_step: @@ -1309,6 +1305,10 @@ class NetworkTrainer: clean_memory_on_device(accelerator.device) + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + validation_steps = ( min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) )