Support LoRA

This commit is contained in:
kohya-ss
2024-10-29 21:52:04 +09:00
parent 0af4edd8a6
commit d4e19fbd5e

View File

@@ -10,11 +10,13 @@ import numpy as np
import torch import torch
from safetensors.torch import safe_open, load_file from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm from tqdm import tqdm
from PIL import Image from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3
init_ipex() init_ipex()
@@ -104,6 +106,7 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0) x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape) # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
with torch.autocast(device_type=device.type, dtype=dtype):
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y) model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float() model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x) batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
@@ -153,7 +156,7 @@ def generate_image(
clip_g.to(device) clip_g.to(device)
t5xxl.to(device) t5xxl.to(device)
with torch.no_grad(): with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt) tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens( lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
@@ -233,13 +236,14 @@ if __name__ == "__main__":
parser.add_argument("--bf16", action="store_true") parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50) parser.add_argument("--steps", type=int, default=50)
# parser.add_argument( parser.add_argument(
# "--lora_weights", "--lora_weights",
# type=str, type=str,
# nargs="*", nargs="*",
# default=[], default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
# ) )
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true")
@@ -294,6 +298,30 @@ if __name__ == "__main__":
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length) tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy() encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive: if not args.interactive:
generate_image( generate_image(
mmdit, mmdit,
@@ -344,13 +372,13 @@ if __name__ == "__main__":
steps = int(opt[1:].strip()) steps = int(opt[1:].strip())
elif opt.startswith("d"): elif opt.startswith("d"):
seed = int(opt[1:].strip()) seed = int(opt[1:].strip())
# elif opt.startswith("m"): elif opt.startswith("m"):
# mutipliers = opt[1:].strip().split(",") mutipliers = opt[1:].strip().split(",")
# if len(mutipliers) != len(lora_models): if len(mutipliers) != len(lora_models):
# logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
# continue continue
# for i, lora_model in enumerate(lora_models): for i, lora_model in enumerate(lora_models):
# lora_model.set_multiplier(float(mutipliers[i])) lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"): elif opt.startswith("n"):
negative_prompt = opt[1:].strip() negative_prompt = opt[1:].strip()
if negative_prompt == "-": if negative_prompt == "-":