feat: add multi-backend attention and related updates for HI2.1 models and scripts

This commit is contained in:
Kohya S
2025-09-20 19:45:33 +09:00
parent f834b2e0d4
commit b090d15f7d
6 changed files with 286 additions and 102 deletions

View File

@@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace:
"--attn_mode",
type=str,
default="torch",
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
help="attention mode",
)
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
@@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace:
if args.lycoris and not lycoris_available:
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
return args
@@ -265,7 +268,7 @@ def load_dit_model(
device,
args.dit,
args.attn_mode,
False,
True, # enable split_attn to trim masked tokens
loading_device,
loading_weight_dtype,
args.fp8_scaled and not args.lycoris,