intergrate fp16/bf16 to model loading

This commit is contained in:
Isotr0py
2023-08-27 19:16:23 +08:00
parent 3d12cdc643
commit 6155f9c171
2 changed files with 40 additions and 11 deletions

View File

@@ -18,6 +18,7 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -36,6 +37,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype
)
# work on low-ram device
@@ -54,7 +56,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
@@ -67,7 +70,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -77,7 +80,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
)
except EnvironmentError as ex:
if variant is not None:
@@ -93,6 +96,13 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
text_encoder1 = pipe.text_encoder
text_encoder2 = pipe.text_encoder_2
# convert to fp32 for cache text_encoders outputs
if text_encoder1.dtype != torch.float32:
text_encoder1 = text_encoder1.to(dtype=torch.float32)
if text_encoder2.dtype != torch.float32:
text_encoder2 = text_encoder2.to(dtype=torch.float32)
vae = pipe.vae
unet = pipe.unet
del pipe
@@ -101,7 +111,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device)
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")
logit_scale = None
@@ -146,6 +156,21 @@ def load_tokenizers(args: argparse.Namespace):
return tokeniers
def match_mixed_precision(args, weight_dtype):
if args.full_fp16:
assert (
weight_dtype == torch.float16
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
return weight_dtype
elif args.full_bf16:
assert (
weight_dtype == torch.bfloat16
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
return weight_dtype
else:
return None
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.