mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
OFT for FLUX.1
This commit is contained in:
@@ -14,9 +14,11 @@ from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import accelerate
|
||||
from transformers import CLIPTextModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from library import device_utils
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
from networks import oft_flux
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -405,7 +407,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||
parser.add_argument("--width", type=int, default=target_width)
|
||||
@@ -482,9 +484,19 @@ if __name__ == "__main__":
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
lora_model, weights_sd = lora_flux.create_network_from_weights(
|
||||
multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
|
||||
)
|
||||
weights_sd = load_file(weights_file)
|
||||
is_lora = is_oft = False
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith("lora"):
|
||||
is_lora = True
|
||||
if key.startswith("oft"):
|
||||
is_oft = True
|
||||
if is_lora or is_oft:
|
||||
break
|
||||
|
||||
module = lora_flux if is_lora else oft_flux
|
||||
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
|
||||
|
||||
if args.merge_lora_weights:
|
||||
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user