mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix to work full_bf16 and full_fp16.
This commit is contained in:
@@ -891,6 +891,14 @@ class MMDiT(nn.Module):
|
||||
def model_type(self):
|
||||
return "m" # only support medium
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.gradient_checkpointing = True
|
||||
for block in self.joint_blocks:
|
||||
|
||||
@@ -28,7 +28,7 @@ def load_models(
|
||||
vae_path: str,
|
||||
attn_mode: str,
|
||||
device: Union[str, torch.device],
|
||||
default_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
weight_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
disable_mmap: bool = False,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
t5xxl_device: Optional[Union[str, torch.device]] = None,
|
||||
@@ -46,7 +46,7 @@ def load_models(
|
||||
vae_path: Path to the VAE checkpoint file.
|
||||
attn_mode: Attention mode for MMDiT model.
|
||||
device: Device for MMDiT model.
|
||||
default_dtype: Default dtype for each model. In training, it's usually None. None means using float32.
|
||||
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
|
||||
disable_mmap: Disable memory mapping when loading state dict.
|
||||
clip_dtype: Dtype for Clip models, or None to use default dtype.
|
||||
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
|
||||
@@ -61,8 +61,6 @@ def load_models(
|
||||
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
|
||||
# Therefore, we need clip_dtype and t5xxl_dtype.
|
||||
|
||||
# default_dtype is used for full_fp16/full_bf16 training.
|
||||
|
||||
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(path, "rb").read())
|
||||
@@ -73,9 +71,9 @@ def load_models(
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
t5xxl_device = t5xxl_device or device
|
||||
clip_dtype = clip_dtype or default_dtype or torch.float32
|
||||
t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32
|
||||
vae_dtype = vae_dtype or default_dtype or torch.float32
|
||||
clip_dtype = clip_dtype or weight_dtype or torch.float32
|
||||
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
|
||||
vae_dtype = vae_dtype or weight_dtype or torch.float32
|
||||
|
||||
logger.info(f"Loading SD3 models from {ckpt_path}...")
|
||||
state_dict = load_state_dict(ckpt_path)
|
||||
@@ -157,7 +155,7 @@ def load_models(
|
||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype)
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
|
||||
logger.info(f"Loaded MMDiT: {info}")
|
||||
|
||||
# load ClipG and ClipL
|
||||
|
||||
Reference in New Issue
Block a user