OFT for FLUX.1

This commit is contained in:
Kohya S
2024-09-14 15:48:16 +09:00
parent 0485f236a0
commit 2d8ee3c280
4 changed files with 504 additions and 6 deletions

View File

@@ -14,9 +14,11 @@ from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux
init_ipex()
@@ -405,7 +407,7 @@ if __name__ == "__main__":
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
@@ -482,9 +484,19 @@ if __name__ == "__main__":
else:
multiplier = 1.0
lora_model, weights_sd = lora_flux.create_network_from_weights(
multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
)
weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break
module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else: