Fix fp16 mixed precision, model is in bf16 without full_bf16

commit 19086465e8
parent 66cf435479
Kohya S committed 2024-06-29 17:21:25 +09:00

4 changed files with 61 additions and 15 deletions

@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
 from .sdxl_train_util import match_mixed_precision
 
-def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[
+def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
     sd3_models.MMDiT,
     Optional[sd3_models.SDClipModel],
     Optional[sd3_models.SDXLClipG],
     Optional[sd3_models.T5XXLModel],
     sd3_models.SDVAE,
 ]:
-    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
+    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16, None or fp16/bf16
 
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
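
The behavior hinges on what match_mixed_precision returns: None unless --full_fp16 or --full_bf16 is set, so plain mixed precision keeps the weights in full precision and lets autocast supply the compute dtype. A minimal sketch of that helper, assuming it matches the sdxl_train_util version it is imported from (the assert messages are illustrative, not copied from the source):

import torch

def match_mixed_precision(args, weight_dtype):
    # weight_dtype is derived from --mixed_precision (fp16 / bf16 / no)
    if args.full_fp16:
        assert weight_dtype == torch.float16, "full_fp16 requires --mixed_precision fp16"
        return weight_dtype  # cast the model weights themselves to fp16
    elif args.full_bf16:
        assert weight_dtype == torch.bfloat16, "full_bf16 requires --mixed_precision bf16"
        return weight_dtype  # cast the model weights themselves to bf16
    else:
        return None  # keep weights in full precision; autocast handles the compute dtype
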
@@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device,
                 args.vae,
                 attn_mode,
                 accelerator.device if args.lowram else "cpu",
-                weight_dtype,
+                model_dtype,
                 args.disable_mmap_load_safetensors,
+                clip_dtype,
                 t5xxl_device,
                 t5xxl_dtype,
+                vae_dtype,
             )
 
-            # work on low-ram device
+            # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
             if args.lowram:
                 if clip_l is not None:
                     clip_l.to(accelerator.device)
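
The second hunk is the actual fix: load_models now receives model_dtype instead of weight_dtype, so a run with --mixed_precision bf16 but without --full_bf16 no longer loads the MMDiT in bf16. A hypothetical sketch of how load_models could consume the argument (the loader name and structure are assumptions, not code from this commit):

# hypothetical dtype handling inside load_models
mmdit = build_mmdit(state_dict)  # assumed loader; weights arrive in the checkpoint dtype / fp32
if model_dtype is not None:      # set only when full_fp16 / full_bf16 is requested
    mmdit.to(model_dtype)        # cast the trained weights themselves
# otherwise leave the weights in full precision; accelerator's autocast provides fp16/bf16 compute
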