Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge pull request #791 from Isotr0py/dev
Integrate fp16/bf16 support into SDXL model loading
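In short, the change threads an optional dtype through SDXL checkpoint loading so the U-Net and VAE weights can be materialized directly in fp16/bf16 when full_fp16/full_bf16 training is requested, while both text encoders stay fp32 (they run on CPU when caching outputs). Everything below keys on the helper _load_state_dict_on_device, whose body is not part of this diff; the following is a minimal sketch of what such a dtype-aware loader can look like, assuming accelerate's set_module_tensor_to_device. It is illustrative only, not the repository's exact implementation:

from accelerate.utils import set_module_tensor_to_device

def _load_state_dict_on_device(model, state_dict, device, dtype=None):
    # Sketch: materialize each checkpoint tensor on `device`, optionally
    # casting to `dtype` (dtype=None keeps the checkpoint precision).
    missing = set(model.state_dict().keys()) - set(state_dict.keys())
    unexpected = set(state_dict.keys()) - set(model.state_dict().keys())
    assert not missing and not unexpected, f"missing: {missing}, unexpected: {unexpected}"
    for k in list(state_dict.keys()):
        # this also materializes meta tensors created under init_empty_weights()
        set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
    return "<All keys matched successfully>"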
library/sdxl_model_util.py
@@ -160,7 +160,7 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
 
 def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
     # model_version is reserved for future use
-    # dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
+    # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
 
     # Load the state dict
     if model_util.is_safetensors(ckpt_path):
@@ -193,7 +193,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
     for k in list(state_dict.keys()):
         if k.startswith("model.diffusion_model."):
             unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
-    info = _load_state_dict_on_device(unet, unet_sd, device=map_location)
+    info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
     print("U-Net: ", info)
 
     # Text Encoders
@@ -221,7 +221,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
         # torch_dtype="float32",
         # transformers_version="4.25.0.dev0",
     )
-    text_model1 = CLIPTextModel._from_config(text_model1_cfg)
+    with init_empty_weights():
+        text_model1 = CLIPTextModel._from_config(text_model1_cfg)
 
     # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
     # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
@@ -246,7 +247,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
         # torch_dtype="float32",
         # transformers_version="4.25.0.dev0",
     )
-    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
+    with init_empty_weights():
+        text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
 
     print("loading text encoders from checkpoint")
     te1_sd = {}
@@ -257,21 +259,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
         elif k.startswith("conditioner.embedders.1.model."):
             te2_sd[k] = state_dict.pop(k)
 
-    info1 = text_model1.load_state_dict(te1_sd)
+    info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
     print("text encoder 1:", info1)
 
     converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
-    info2 = text_model2.load_state_dict(converted_sd)
+    info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
     print("text encoder 2:", info2)
 
     # prepare vae
     print("building VAE")
     vae_config = model_util.create_vae_diffusers_config()
-    vae = AutoencoderKL(**vae_config)  # .to(device)
+    with init_empty_weights():
+        vae = AutoencoderKL(**vae_config)
 
     print("loading VAE from checkpoint")
     converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
-    info = vae.load_state_dict(converted_vae_checkpoint)
+    info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
     print("VAE:", info)
 
     ckpt_info = (epoch, global_step) if epoch is not None else None
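The recurring pattern in this file (building each module under init_empty_weights() and then loading it with _load_state_dict_on_device) allocates parameters on the meta device first, so no throwaway fp32 copy is ever created; tensors only become real, already at the requested dtype, when the checkpoint is read. A small self-contained illustration of the accelerate idiom follows; the toy Linear layer is an assumption for demonstration only:

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4, 4)  # parameters live on the meta device, no memory allocated
print(layer.weight.device)  # meta

# materialize from "checkpoint" tensors, casting to fp16 as they land
ckpt = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
for name, value in ckpt.items():
    set_module_tensor_to_device(layer, name, "cpu", value=value, dtype=torch.float16)
print(layer.weight.dtype)  # torch.float16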

library/sdxl_train_util.py
@@ -18,6 +18,7 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 
 def load_target_model(args, accelerator, model_version: str, weight_dtype):
     # load models for each process
+    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -36,6 +37,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 model_version,
                 weight_dtype,
                 accelerator.device if args.lowram else "cpu",
+                model_dtype,
             )
 
             # work on low-ram device
@@ -54,7 +56,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
 
 
-def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
+def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None):
+    # model_dtype only works with full fp16/bf16
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
 
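Note the distinction the new parameter introduces: weight_dtype is the mixed-precision compute dtype, while model_dtype is the storage dtype for the U-Net/VAE weights at load time, and it stays None unless full_fp16/full_bf16 is set (see match_mixed_precision further down). A hypothetical call for illustration; the checkpoint path and version string are placeholders, not values taken from this diff:

import torch

# full_fp16: U-Net/VAE weights land directly in fp16
_load_target_model(
    "sd_xl_base_1.0.safetensors", None, "sdxl_base_v1-0",
    torch.float16, device="cpu", model_dtype=torch.float16,
)

# plain fp16 mixed precision: model_dtype defaults to None, weights stay fp32
_load_target_model(
    "sd_xl_base_1.0.safetensors", None, "sdxl_base_v1-0",
    torch.float16, device="cpu",
)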
@@ -67,7 +70,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
             unet,
             logit_scale,
             ckpt_info,
-        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
     else:
         # Diffusers model is loaded to CPU
         from diffusers import StableDiffusionXLPipeline
@@ -77,7 +80,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
         try:
             try:
                 pipe = StableDiffusionXLPipeline.from_pretrained(
-                    name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None
+                    name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
                 )
             except EnvironmentError as ex:
                 if variant is not None:
@@ -93,6 +96,13 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
 
         text_encoder1 = pipe.text_encoder
         text_encoder2 = pipe.text_encoder_2
+
+        # convert to fp32 for caching text encoder outputs
+        if text_encoder1.dtype != torch.float32:
+            text_encoder1 = text_encoder1.to(dtype=torch.float32)
+        if text_encoder2.dtype != torch.float32:
+            text_encoder2 = text_encoder2.to(dtype=torch.float32)
+
         vae = pipe.vae
         unet = pipe.unet
         del pipe
@@ -101,7 +111,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
         state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
         with init_empty_weights():
             unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
-        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
+        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
         print("U-Net converted to original U-Net")
 
         logit_scale = None
@@ -146,6 +156,21 @@ def load_tokenizers(args: argparse.Namespace):
     return tokeniers
 
 
+def match_mixed_precision(args, weight_dtype):
+    if args.full_fp16:
+        assert (
+            weight_dtype == torch.float16
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        return weight_dtype
+    elif args.full_bf16:
+        assert (
+            weight_dtype == torch.bfloat16
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        return weight_dtype
+    else:
+        return None
+
+
 def timestep_embedding(timesteps, dim, max_period=10000):
     """
     Create sinusoidal timestep embeddings.
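For reference, match_mixed_precision (added above) is the piece that decides whether a model_dtype is passed down at all. A quick usage sketch, with argparse.Namespace standing in for the script's parsed arguments; the import path assumes the function lives in library/sdxl_train_util.py:

import argparse
import torch
from library.sdxl_train_util import match_mixed_precision  # assumed module path

args = argparse.Namespace(full_fp16=True, full_bf16=False)  # stand-in for parsed args
print(match_mixed_precision(args, torch.float16))  # torch.float16 -> load U-Net/VAE in fp16

args = argparse.Namespace(full_fp16=False, full_bf16=False)
print(match_mixed_precision(args, torch.float16))  # None -> weights remain fp32 at load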