mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into val-loss-improvement
This commit is contained in:
@@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
|||||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||||
|
|
||||||
# rename or add position_ids
|
# remove position_ids for newer transformer, which causes error :(
|
||||||
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
||||||
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
||||||
# waifu diffusion v1.4
|
# waifu diffusion v1.4
|
||||||
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
|
||||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||||
else:
|
|
||||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
|
||||||
|
|
||||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
if "text_model.embeddings.position_ids" in new_sd:
|
||||||
|
del new_sd["text_model.embeddings.position_ids"]
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -344,8 +344,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
|
|||||||
|
|
||||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||||
if args.v_parameterization:
|
|
||||||
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
|
||||||
|
|
||||||
if args.clip_skip is not None:
|
if args.clip_skip is not None:
|
||||||
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||||
|
|||||||
@@ -1207,10 +1207,6 @@ class NetworkTrainer:
|
|||||||
args.max_train_steps > initial_step
|
args.max_train_steps > initial_step
|
||||||
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
||||||
|
|
||||||
progress_bar = tqdm(
|
|
||||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
|
||||||
)
|
|
||||||
|
|
||||||
epoch_to_start = 0
|
epoch_to_start = 0
|
||||||
if initial_step > 0:
|
if initial_step > 0:
|
||||||
if args.skip_until_initial_step:
|
if args.skip_until_initial_step:
|
||||||
@@ -1309,6 +1305,10 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
progress_bar = tqdm(
|
||||||
|
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||||
|
)
|
||||||
|
|
||||||
validation_steps = (
|
validation_steps = (
|
||||||
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user