feat: add block swap for FLUX.1/SD3 LoRA training

This commit is contained in:
Kohya S
2024-11-12 21:39:13 +09:00
parent 17cf249d76
commit 2cb7a6db02
14 changed files with 288 additions and 629 deletions

View File

@@ -1887,7 +1887,9 @@ class DreamBoothDataset(BaseDataset):
# make image path to npz path mapping
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix
npz_paths.sort(
key=lambda item: item.rsplit("_", maxsplit=2)[0]
) # sort by name excluding resolution and cache_suffix
npz_path_index = 0
size_set_count = 0
@@ -3537,8 +3539,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--fused_backward_pass",
action="store_true",
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX"
" / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能",
)
parser.add_argument(
"--lr_scheduler_timescale",
@@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
)
def add_dit_training_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the DiT training command-line arguments on ``parser``.

    The options are added in four groups: text encoder output caching,
    safetensors loading, timestep weighting, and block swapping.

    Args:
        parser: Argument parser that receives the options; it is modified in
            place and nothing is returned.
    """
    # Text encoder related arguments
    parser.add_argument(
        "--cache_text_encoder_outputs",
        action="store_true",
        help="cache text encoder outputs / text encoderの出力をキャッシュする",
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )
    parser.add_argument(
        "--text_encoder_batch_size",
        type=int,
        default=None,
        # The Japanese half previously had a closing parenthesis without an opening one.
        help="text encoder batch size (default: None, use dataset's batch size)"
        " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
    )

    # Model loading optimization
    parser.add_argument(
        "--disable_mmap_load_safetensors",
        action="store_true",
        help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
    )

    # Training arguments. partial copy from Diffusers
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"],
        help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior"
        " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動",
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
    )

    # offloading
    parser.add_argument(
        "--blocks_to_swap",
        type=int,
        default=None,
        # Adjacent literals are concatenated verbatim, so each sentence must
        # carry its own trailing space or the help text runs together.
        help="[EXPERIMENTAL] "
        "Sets the number of blocks to swap during the forward and backward passes. "
        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
        " / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
    )
def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging. if not, return None.
# when `--log_config is enabled, filter out sensitive values from args