mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: add multi backend attention and related update for HI2.1 models and scripts
This commit is contained in:
@@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace:
|
||||
"--attn_mode",
|
||||
type=str,
|
||||
default="torch",
|
||||
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
|
||||
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
|
||||
help="attention mode",
|
||||
)
|
||||
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
||||
@@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace:
|
||||
if args.lycoris and not lycoris_available:
|
||||
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -265,7 +268,7 @@ def load_dit_model(
|
||||
device,
|
||||
args.dit,
|
||||
args.attn_mode,
|
||||
False,
|
||||
True, # enable split_attn to trim masked tokens
|
||||
loading_device,
|
||||
loading_weight_dtype,
|
||||
args.fp8_scaled and not args.lycoris,
|
||||
|
||||
Reference in New Issue
Block a user