load models one by one

This commit is contained in:
Kohya S
2024-07-08 22:04:43 +09:00
parent c9de7c4e9a
commit 3ea4fce5e0
3 changed files with 236 additions and 47 deletions

View File

@@ -1,19 +1,17 @@
import argparse
import math
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
from safetensors.torch import save_file
from accelerate import Accelerator
from library import sd3_models, sd3_utils, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
# from transformers import CLIPTokenizer
# from library import model_util
# , sdxl_model_util, train_util, sdxl_original_unet
@@ -28,50 +26,48 @@ logger = logging.getLogger(__name__)
from .sdxl_train_util import match_mixed_precision
def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
def load_target_model(
model_type: str,
args: argparse.Namespace,
state_dict: dict,
accelerator: Accelerator,
attn_mode: str,
model_dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> Union[
sd3_models.MMDiT,
Optional[sd3_models.SDClipModel],
Optional[sd3_models.SDXLClipG],
Optional[sd3_models.T5XXLModel],
sd3_models.SDVAE,
]:
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models(
args.pretrained_model_name_or_path,
args.clip_l,
args.clip_g,
args.t5xxl,
args.vae,
attn_mode,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors,
clip_dtype,
t5xxl_device,
t5xxl_dtype,
vae_dtype,
)
if model_type == "mmdit":
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
elif model_type == "clip_l":
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
elif model_type == "clip_g":
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
elif model_type == "t5xxl":
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
elif model_type == "vae":
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
else:
raise ValueError(f"Unknown model type: {model_type}")
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
if args.lowram:
if clip_l is not None:
clip_l.to(accelerator.device)
if clip_g is not None:
clip_g.to(accelerator.device)
if t5xxl is not None:
t5xxl.to(accelerator.device)
vae.to(accelerator.device)
mmdit.to(accelerator.device)
model = model.to(accelerator.device)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return mmdit, clip_l, clip_g, t5xxl, vae
return model
def save_models(