mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update sdxl ver in lora metadata from v0-9 to v1-0
This commit is contained in:
@@ -10,10 +10,10 @@ from library import sdxl_original_unet
|
|||||||
|
|
||||||
|
|
||||||
VAE_SCALE_FACTOR = 0.13025
|
VAE_SCALE_FACTOR = 0.13025
|
||||||
MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
|
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
||||||
|
|
||||||
# Diffusersの設定を読み込むための参照モデル
|
# Diffusersの設定を読み込むための参照モデル
|
||||||
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-0.9" # アクセス権が必要
|
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||||
|
|
||||||
DIFFUSERS_SDXL_UNET_CONFIG = {
|
DIFFUSERS_SDXL_UNET_CONFIG = {
|
||||||
"act_fn": "silu",
|
"act_fn": "silu",
|
||||||
|
|||||||
@@ -61,15 +61,15 @@ def svd(args):
|
|||||||
else:
|
else:
|
||||||
print(f"loading original SDXL model : {args.model_org}")
|
print(f"loading original SDXL model : {args.model_org}")
|
||||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||||
print(f"loading original SDXL model : {args.model_tuned}")
|
print(f"loading original SDXL model : {args.model_tuned}")
|
||||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
|
||||||
)
|
)
|
||||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9
|
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||||
|
|
||||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||||
if args.conv_dim is None:
|
if args.conv_dim is None:
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ def merge(args):
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
||||||
|
|
||||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -1294,7 +1294,7 @@ def main(args):
|
|||||||
args.ckpt = files[0]
|
args.ckpt = files[0]
|
||||||
|
|
||||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
|
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# xformers、Hypernetwork対応
|
# xformers、Hypernetwork対応
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ if __name__ == "__main__":
|
|||||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||||
# If the main RAM is small, it may be better to load it on the GPU
|
# If the main RAM is small, it may be better to load it on the GPU
|
||||||
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||||
assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
# assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
||||||
|
|
||||||
# Diffusers版のxformers使用フラグを設定する関数
|
# Diffusers版のxformers使用フラグを設定する関数
|
||||||
def set_diffusers_xformers_flag(model, valid):
|
def set_diffusers_xformers_flag(model, valid):
|
||||||
|
|||||||
@@ -32,13 +32,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||||
|
|
||||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||||
self.logit_scale = logit_scale
|
self.logit_scale = logit_scale
|
||||||
self.ckpt_info = ckpt_info
|
self.ckpt_info = ckpt_info
|
||||||
|
|
||||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def load_tokenizer(self, args):
|
||||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||||
|
|||||||
@@ -28,13 +28,13 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||||
|
|
||||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||||
self.logit_scale = logit_scale
|
self.logit_scale = logit_scale
|
||||||
self.ckpt_info = ckpt_info
|
self.ckpt_info = ckpt_info
|
||||||
|
|
||||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||||
|
|
||||||
def load_tokenizer(self, args):
|
def load_tokenizer(self, args):
|
||||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user