mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Add dtype to U-Net loading
This commit is contained in:
@@ -135,7 +135,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
return new_sd, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
|
|
||||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype):
|
||||||
# model_version is reserved for future use
|
# model_version is reserved for future use
|
||||||
|
|
||||||
# Load the state dict
|
# Load the state dict
|
||||||
@@ -167,7 +167,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
print("loading U-Net from checkpoint")
|
print("loading U-Net from checkpoint")
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
if k.startswith("model.diffusion_model."):
|
if k.startswith("model.diffusion_model."):
|
||||||
set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k))
|
set_module_tensor_to_device(
|
||||||
|
unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype
|
||||||
|
)
|
||||||
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
|
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
|
||||||
# print("U-Net: ", info)
|
# print("U-Net: ", info)
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
|
|
||||||
|
|
||||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||||
|
# TODO: integrate full fp16/bf16 to model loading
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
@@ -67,7 +68,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
|
||||||
else:
|
else:
|
||||||
# Diffusers model is loaded to CPU
|
# Diffusers model is loaded to CPU
|
||||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||||
@@ -98,7 +99,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
|||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k))
|
set_module_tensor_to_device(unet, k, device, value=state_dict.pop(k), dtype=weight_dtype)
|
||||||
print("U-Net converted to original U-Net")
|
print("U-Net converted to original U-Net")
|
||||||
|
|
||||||
logit_scale = None
|
logit_scale = None
|
||||||
|
|||||||
Reference in New Issue
Block a user