mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
add dtype to u-net loading
This commit is contained in:
@@ -135,7 +135,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype):
|
||||
# model_version is reserved for future use
|
||||
|
||||
# Load the state dict
|
||||
@@ -167,7 +167,9 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
print("loading U-Net from checkpoint")
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("model.diffusion_model."):
|
||||
set_module_tensor_to_device(unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k))
|
||||
set_module_tensor_to_device(
|
||||
unet, k.replace("model.diffusion_model.", ""), map_location, value=state_dict.pop(k), dtype=dtype
|
||||
)
|
||||
# TODO: catch missing_keys and unexpected_keys with _IncompatibleKeys
|
||||
# print("U-Net: ", info)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user