Update train_network.py

This commit is contained in:
gesen2egee
2024-10-07 18:01:00 +08:00
parent 012e7e63a5
commit f09824fd31

View File

@@ -205,6 +205,9 @@ class NetworkTrainer:
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -1162,6 +1165,11 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument(
"--no_token_padding",
action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作",
)
parser.add_argument(
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
)