mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add sdxl fine-tuning and LoRA
This commit is contained in:
@@ -11,10 +11,13 @@ import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import sdxl_model_util
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
from PIL import Image
|
||||
import open_clip
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from library import model_util, sdxl_model_util
|
||||
import networks.lora as lora
|
||||
|
||||
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
||||
# scheduler: The settings around here seem to be the same as SD1/2
|
||||
@@ -85,6 +88,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||
parser.add_argument("--negative_prompt", type=str, default="")
|
||||
parser.add_argument("--output_dir", type=str, default=".")
|
||||
parser.add_argument(
|
||||
"--lora_weights",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# HuggingFaceのmodel id
|
||||
@@ -97,7 +107,7 @@ if __name__ == "__main__":
|
||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||
# If the main RAM is small, it may be better to load it on the GPU
|
||||
text_model1, text_model2, vae, unet, text_projection, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
"sdxl_base_v0-9", args.ckpt_path, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
||||
)
|
||||
|
||||
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
||||
@@ -134,6 +144,19 @@ if __name__ == "__main__":
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
|
||||
# LoRA
|
||||
for weights_file in args.lora_weights:
|
||||
if ";" in weights_file:
|
||||
weights_file, multiplier = weights_file.split(";")
|
||||
multiplier = float(multiplier)
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
lora_model, weights_sd = lora.create_network_from_weights(
|
||||
multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
|
||||
)
|
||||
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
||||
|
||||
# prepare embedding
|
||||
with torch.no_grad():
|
||||
# vector
|
||||
@@ -248,7 +271,7 @@ if __name__ == "__main__":
|
||||
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
||||
|
||||
# latents = 1 / 0.18215 * latents
|
||||
latents = 1 / 0.13025 * latents
|
||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||
latents = latents.to(torch.float32)
|
||||
image = vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
Reference in New Issue
Block a user