mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: add block swap for FLUX.1/SD3 LoRA training
This commit is contained in:
@@ -1887,7 +1887,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
# make image path to npz path mapping
|
||||
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
|
||||
npz_paths.sort(key=lambda item: item.rsplit("_", maxsplit=2)[0]) # sort by name excluding resolution and cache_suffix
|
||||
npz_paths.sort(
|
||||
key=lambda item: item.rsplit("_", maxsplit=2)[0]
|
||||
) # sort by name excluding resolution and cache_suffix
|
||||
npz_path_index = 0
|
||||
|
||||
size_set_count = 0
|
||||
@@ -3537,8 +3539,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--fused_backward_pass",
|
||||
action="store_true",
|
||||
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL"
|
||||
+ " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効",
|
||||
help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL, SD3 and FLUX"
|
||||
" / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXL、SD3、FLUXでのみ利用可能",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_timescale",
|
||||
@@ -4027,6 +4029,72 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
def add_dit_training_arguments(parser: argparse.ArgumentParser):
    """Register the shared DiT training command-line options on ``parser``.

    The options fall into four groups:

    * text encoder output caching (``--cache_text_encoder_outputs``,
      ``--cache_text_encoder_outputs_to_disk``, ``--text_encoder_batch_size``)
    * model loading (``--disable_mmap_load_safetensors``)
    * timestep weighting (``--weighting_scheme``, ``--logit_mean``,
      ``--logit_std``, ``--mode_scale``)
    * block offloading (``--blocks_to_swap``)

    Args:
        parser: The argument parser to extend. It is modified in place.
    """
    # Text encoder related arguments
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )
    parser.add_argument(
        "--text_encoder_batch_size",
        type=int,
        default=None,
        help="text encoder batch size (default: None, use dataset's batch size)"
        + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
    )

    # Model loading optimization
    parser.add_argument(
        "--disable_mmap_load_safetensors",
        action="store_true",
        help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
    )

    # Training arguments. partial copy from Diffusers
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="uniform",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none", "uniform"],
        help="weighting scheme for timestep distribution. Default is uniform, uniform and none are the same behavior"
        " / タイムステップ分布の重み付けスキーム、デフォルトはuniform、uniform と none は同じ挙動",
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合の平均",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme / `'logit_normal'`重み付けスキームを使用する場合のstd",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme` / モード重み付けスキームのスケール",
    )

    # offloading
    parser.add_argument(
        "--blocks_to_swap",
        type=int,
        default=None,
        # Adjacent literals are concatenated verbatim, so the separating space
        # between the two English sentences must be part of the literal itself.
        help="[EXPERIMENTAL] "
        "Sets the number of blocks to swap during the forward and backward passes. "
        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
        " / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
    )
|
||||
|
||||
|
||||
def get_sanitized_config_or_none(args: argparse.Namespace):
|
||||
# if `--log_config` is enabled, return args for logging. if not, return None.
|
||||
# when `--log_config` is enabled, filter out sensitive values from args
|
||||
|
||||
Reference in New Issue
Block a user