Fix full_bf16 and full_fp16 training to work.

This commit is contained in:
Kohya S
2024-06-29 17:45:50 +09:00
parent 19086465e8
commit ea18d5ba6d
3 changed files with 24 additions and 18 deletions

View File

@@ -891,6 +891,14 @@ class MMDiT(nn.Module):
def model_type(self):
return "m" # only support medium
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
for block in self.joint_blocks:

View File

@@ -28,7 +28,7 @@ def load_models(
vae_path: str,
attn_mode: str,
device: Union[str, torch.device],
default_dtype: Optional[Union[str, torch.dtype]] = None,
weight_dtype: Optional[Union[str, torch.dtype]] = None,
disable_mmap: bool = False,
clip_dtype: Optional[Union[str, torch.dtype]] = None,
t5xxl_device: Optional[Union[str, torch.device]] = None,
@@ -46,7 +46,7 @@ def load_models(
vae_path: Path to the VAE checkpoint file.
attn_mode: Attention mode for MMDiT model.
device: Device for MMDiT model.
default_dtype: Default dtype for each model. In training, it's usually None. None means using float32.
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
disable_mmap: Disable memory mapping when loading state dict.
clip_dtype: Dtype for Clip models, or None to use default dtype.
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
@@ -61,8 +61,6 @@ def load_models(
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
# Therefore, we need clip_dtype and t5xxl_dtype.
# default_dtype is used for full_fp16/full_bf16 training.
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
@@ -73,9 +71,9 @@ def load_models(
return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device
clip_dtype = clip_dtype or default_dtype or torch.float32
t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32
vae_dtype = vae_dtype or default_dtype or torch.float32
clip_dtype = clip_dtype or weight_dtype or torch.float32
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
vae_dtype = vae_dtype or weight_dtype or torch.float32
logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path)
@@ -157,7 +155,7 @@ def load_models(
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype)
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL

View File

@@ -182,7 +182,7 @@ def train(args):
raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}")
t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device
clip_dtype = weight_dtype # if not args.train_text_encoder else None
clip_dtype = weight_dtype # if not args.train_text_encoder else None
# モデルを読み込む
attn_mode = "xformers" if args.xformers else "torch"
@@ -193,7 +193,7 @@ def train(args):
# models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
)
assert clip_l is not None, "clip_l is required / clip_lは必須です"
assert clip_g is not None, "clip_g is required / clip_gは必須です"
@@ -769,10 +769,10 @@ def train(args):
epoch,
num_train_epochs,
global_step,
clip_l if args.save_clip else None,
clip_g if args.save_clip else None,
t5xxl if args.save_t5xxl else None,
mmdit,
accelerator.unwrap_model(clip_l) if args.save_clip else None,
accelerator.unwrap_model(clip_g) if args.save_clip else None,
accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None,
accelerator.unwrap_model(mmdit),
vae,
)
@@ -807,10 +807,10 @@ def train(args):
epoch,
num_train_epochs,
global_step,
clip_l if args.save_clip else None,
clip_g if args.save_clip else None,
t5xxl if args.save_t5xxl else None,
mmdit,
accelerator.unwrap_model(clip_l) if args.save_clip else None,
accelerator.unwrap_model(clip_g) if args.save_clip else None,
accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None,
accelerator.unwrap_model(mmdit),
vae,
)