Fix fp16 mixed precision; model was in bf16 even without full_bf16

Kohya S
2024-06-29 17:21:25 +09:00
parent 66cf435479
commit 19086465e8
4 changed files with 61 additions and 15 deletions

View File

@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
from .sdxl_train_util import match_mixed_precision
- def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[
+ def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
sd3_models.MMDiT,
Optional[sd3_models.SDClipModel],
Optional[sd3_models.SDXLClipG],
Optional[sd3_models.T5XXLModel],
sd3_models.SDVAE,
]:
- model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
+ model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16: None, or fp16/bf16 when full_fp16/full_bf16 is set
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
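
Note: `match_mixed_precision` comes from `.sdxl_train_util` and is what decides whether a dtype is forced onto the model at load time. A minimal sketch of its expected behavior (hedged; not the exact implementation in the repository):

```python
import torch

# Sketch only: the real helper lives in sdxl_train_util and may differ in detail.
def match_mixed_precision_sketch(args, weight_dtype):
    if args.full_fp16:
        # full fp16 training keeps the model weights themselves in fp16
        assert weight_dtype == torch.float16, "full_fp16 requires --mixed_precision fp16"
        return weight_dtype
    elif args.full_bf16:
        # full bf16 training keeps the model weights themselves in bf16
        assert weight_dtype == torch.bfloat16, "full_bf16 requires --mixed_precision bf16"
        return weight_dtype
    # plain fp16/bf16 mixed precision: no forced dtype, so the model stays in float32
    return None
```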
@@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device,
args.vae,
attn_mode,
accelerator.device if args.lowram else "cpu",
- weight_dtype,
+ model_dtype,
args.disable_mmap_load_safetensors,
+ clip_dtype,
t5xxl_device,
t5xxl_dtype,
+ vae_dtype,
)
- # work on low-ram device
+ # work on low-ram device: models are already loaded on accelerator.device, but we make sure they are on the device
if args.lowram:
if clip_l is not None:
clip_l.to(accelerator.device)
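
The effect of forwarding `model_dtype` instead of `weight_dtype` can be seen in a small, self-contained illustration (flag and variable names mirror the training options; this is not code from the repository):

```python
import torch

mixed_precision = "bf16"   # --mixed_precision bf16
full_bf16 = False          # --full_bf16 not set

weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
model_dtype = weight_dtype if full_bf16 else None  # what match_mixed_precision yields

# Before this commit, weight_dtype (bf16) was forwarded to load_models, so the MMDiT
# weights were loaded in bf16 even without --full_bf16. With the fix, model_dtype
# (None) is forwarded and the transformer weights fall back to float32.
print(weight_dtype, model_dtype)  # torch.bfloat16 None
```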

View File

@@ -28,11 +28,41 @@ def load_models(
vae_path: str,
attn_mode: str,
device: Union[str, torch.device],
- weight_dtype: torch.dtype,
+ default_dtype: Optional[Union[str, torch.dtype]] = None,
disable_mmap: bool = False,
- t5xxl_device: Optional[str] = None,
- t5xxl_dtype: Optional[str] = None,
+ clip_dtype: Optional[Union[str, torch.dtype]] = None,
+ t5xxl_device: Optional[Union[str, torch.device]] = None,
+ t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
+ vae_dtype: Optional[Union[str, torch.dtype]] = None,
):
"""
Load SD3 models from checkpoint files.
Args:
ckpt_path: Path to the SD3 checkpoint file.
clip_l_path: Path to the clip_l checkpoint file.
clip_g_path: Path to the clip_g checkpoint file.
t5xxl_path: Path to the t5xxl checkpoint file.
vae_path: Path to the VAE checkpoint file.
attn_mode: Attention mode for MMDiT model.
device: Device for MMDiT model.
default_dtype: Default dtype for each model. In training, it's usually None. None means using float32.
disable_mmap: Disable memory mapping when loading state dict.
clip_dtype: Dtype for Clip models, or None to use default dtype.
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
vae_dtype: Dtype for VAE model, or None to use default dtype.
Returns:
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
"""
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
# Therefore, we need clip_dtype and t5xxl_dtype.
# default_dtype is used for full_fp16/full_bf16 training.
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
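
A hedged usage sketch of the updated signature: the module path, the "torch" attention mode, and passing None for the sub-model paths (to reuse weights from the single checkpoint) are assumptions, not verified against the repository. With plain fp16 mixed precision, `default_dtype` stays None so the MMDiT remains in float32 while the text encoders and VAE are requested in fp16:

```python
import torch
from library import sd3_utils  # assumed module path for load_models

mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models(
    "sd3_medium.safetensors",  # ckpt_path (placeholder)
    None,                      # clip_l_path (assumed: fall back to the main checkpoint)
    None,                      # clip_g_path
    None,                      # t5xxl_path
    None,                      # vae_path
    "torch",                   # attn_mode (assumed value)
    "cuda",                    # device for the MMDiT model
    default_dtype=None,        # mixed precision without full_fp16/full_bf16
    disable_mmap=False,
    clip_dtype=torch.float16,  # text encoders can still be requested in fp16
    t5xxl_device="cpu",        # optionally keep T5XXL on CPU to save VRAM
    t5xxl_dtype=torch.float16,
    vae_dtype=torch.float16,
)
```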
@@ -43,6 +73,9 @@ def load_models(
return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device
+ clip_dtype = clip_dtype or default_dtype or torch.float32
+ t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32
+ vae_dtype = vae_dtype or default_dtype or torch.float32
logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path)
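
The `explicit or default_dtype or torch.float32` pattern picks the most specific dtype available. A minimal illustration of the fallback order (values chosen only for demonstration):

```python
import torch

def resolve_dtype(explicit, default_dtype):
    # mirrors the `explicit or default_dtype or torch.float32` fallback above
    return explicit or default_dtype or torch.float32

print(resolve_dtype(None, None))                     # plain mixed precision -> torch.float32
print(resolve_dtype(None, torch.float16))            # full_fp16 training     -> torch.float16
print(resolve_dtype(torch.bfloat16, torch.float16))  # explicit dtype wins    -> torch.bfloat16
```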
@@ -124,7 +157,7 @@ def load_models(
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...")
- info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
+ info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype)
logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL
@@ -132,7 +165,7 @@ def load_models(
clip_l = None
else:
logger.info("Building ClipL")
- clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd)
+ clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
@@ -142,7 +175,7 @@ def load_models(
clip_g = None
else:
logger.info("Building ClipG")
- clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd)
+ clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
@@ -165,6 +198,7 @@ def load_models(
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
+ vae.to(device=device, dtype=vae_dtype)
return mmdit, clip_l, clip_g, t5xxl, vae
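
As a final sanity check (hedged, reusing the objects from the usage sketch above), the loaded models should end up with the expected dtypes under plain mixed precision:

```python
# mmdit and vae come from the load_models sketch above; both are torch.nn.Module instances.
print(next(mmdit.parameters()).dtype)  # expected: torch.float32 (default_dtype=None)
print(next(vae.parameters()).dtype)    # expected: torch.float16 (explicit vae_dtype)
```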