mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
feat: add block swap for FLUX.1/SD3 LoRA training
This commit is contained in:
@@ -257,14 +257,9 @@ def sample_image_inference(
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
|
||||
import wandb
|
||||
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log(
|
||||
{f"sample_{i}": wandb.Image(
|
||||
image,
|
||||
caption=prompt # positive prompt as a caption
|
||||
)},
|
||||
commit=False
|
||||
)
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: torch.Tensor):
|
||||
@@ -324,7 +319,7 @@ def denoise(
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
@@ -549,44 +544,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
action="store_true",
|
||||
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_batch_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="text encoder batch size (default: None, use dataset's batch size)"
|
||||
+ " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_mmap_load_safetensors",
|
||||
action="store_true",
|
||||
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
"--weighting_scheme",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||
)
|
||||
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
||||
parser.add_argument(
|
||||
"--mode_scale",
|
||||
type=float,
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
|
||||
Reference in New Issue
Block a user