add sdxl fine-tuning and LoRA

This commit is contained in:
Kohya S
2023-06-26 08:07:24 +09:00
parent 9e9df2b501
commit 747af145ed
11 changed files with 2442 additions and 754 deletions

View File

@@ -11,10 +11,13 @@ import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import sdxl_model_util
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
@@ -85,6 +88,13 @@ if __name__ == "__main__":
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
)
args = parser.parse_args()
# HuggingFaceのmodel id
@@ -97,7 +107,7 @@ if __name__ == "__main__":
# 本体RAMが少ない場合はGPUにロードするといいかも
# If the main RAM is small, it may be better to load it on the GPU
text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
"sdxl_base_v0-9", args.ckpt_path, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
)
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
@@ -134,6 +144,19 @@ if __name__ == "__main__":
unet.set_use_memory_efficient_attention(True, False)
# LoRA
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
lora_model, weights_sd = lora.create_network_from_weights(
multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
)
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
# prepare embedding
with torch.no_grad():
# vector
@@ -248,7 +271,7 @@ if __name__ == "__main__":
latents = scheduler.step(noise_pred, t, latents).prev_sample
# latents = 1 / 0.18215 * latents
latents = 1 / 0.13025 * latents
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
latents = latents.to(torch.float32)
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)