mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix fp16 mixed precision, model is in bf16 without full_bf16
This commit is contained in:
@@ -28,14 +28,14 @@ logger = logging.getLogger(__name__)
|
||||
from .sdxl_train_util import match_mixed_precision
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device, t5xxl_dtype) -> Tuple[
|
||||
def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
|
||||
sd3_models.MMDiT,
|
||||
Optional[sd3_models.SDClipModel],
|
||||
Optional[sd3_models.SDXLClipG],
|
||||
Optional[sd3_models.T5XXLModel],
|
||||
sd3_models.SDVAE,
|
||||
]:
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16
|
||||
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
@@ -49,13 +49,15 @@ def load_target_model(args, accelerator, attn_mode, weight_dtype, t5xxl_device,
|
||||
args.vae,
|
||||
attn_mode,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
weight_dtype,
|
||||
model_dtype,
|
||||
args.disable_mmap_load_safetensors,
|
||||
clip_dtype,
|
||||
t5xxl_device,
|
||||
t5xxl_dtype,
|
||||
vae_dtype,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
||||
if args.lowram:
|
||||
if clip_l is not None:
|
||||
clip_l.to(accelerator.device)
|
||||
|
||||
Reference in New Issue
Block a user