diff --git a/library/sd3_models.py b/library/sd3_models.py index c19aec6a..7041420c 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -891,6 +891,14 @@ class MMDiT(nn.Module): def model_type(self): return "m" # only support medium + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + def enable_gradient_checkpointing(self): self.gradient_checkpointing = True for block in self.joint_blocks: diff --git a/library/sd3_utils.py b/library/sd3_utils.py index 45b49b04..9dc9e796 100644 --- a/library/sd3_utils.py +++ b/library/sd3_utils.py @@ -28,7 +28,7 @@ def load_models( vae_path: str, attn_mode: str, device: Union[str, torch.device], - default_dtype: Optional[Union[str, torch.dtype]] = None, + weight_dtype: Optional[Union[str, torch.dtype]] = None, disable_mmap: bool = False, clip_dtype: Optional[Union[str, torch.dtype]] = None, t5xxl_device: Optional[Union[str, torch.device]] = None, @@ -46,7 +46,7 @@ def load_models( vae_path: Path to the VAE checkpoint file. attn_mode: Attention mode for MMDiT model. device: Device for MMDiT model. - default_dtype: Default dtype for each model. In training, it's usually None. None means using float32. + weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different. disable_mmap: Disable memory mapping when loading state dict. clip_dtype: Dtype for Clip models, or None to use default dtype. t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device. @@ -61,8 +61,6 @@ def load_models( # However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict. # Therefore, we need clip_dtype and t5xxl_dtype. - # default_dtype is used for full_fp16/full_bf16 training. - def load_state_dict(path: str, dvc: Union[str, torch.device] = device): if disable_mmap: return safetensors.torch.load(open(path, "rb").read()) @@ -73,9 +71,9 @@ def load_models( return load_file(path) # prevent device invalid Error t5xxl_device = t5xxl_device or device - clip_dtype = clip_dtype or default_dtype or torch.float32 - t5xxl_dtype = t5xxl_dtype or default_dtype or torch.float32 - vae_dtype = vae_dtype or default_dtype or torch.float32 + clip_dtype = clip_dtype or weight_dtype or torch.float32 + t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32 + vae_dtype = vae_dtype or weight_dtype or torch.float32 logger.info(f"Loading SD3 models from {ckpt_path}...") state_dict = load_state_dict(ckpt_path) @@ -157,7 +155,7 @@ def load_models( mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode) logger.info("Loading state dict...") - info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, default_dtype) + info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype) logger.info(f"Loaded MMDiT: {info}") # load ClipG and ClipL diff --git a/sd3_train.py b/sd3_train.py index bd30cdc7..de763ac6 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -182,7 +182,7 @@ def train(args): raise ValueError(f"unexpected t5xxl_dtype: {args.t5xxl_dtype}") t5xxl_device = accelerator.device if args.t5xxl_device is None else args.t5xxl_device - clip_dtype = weight_dtype # if not args.train_text_encoder else None + clip_dtype = weight_dtype # if not args.train_text_encoder else None # モデルを読み込む attn_mode = "xformers" if args.xformers else "torch" @@ -193,7 +193,7 @@ def train(args): # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0. mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model( - args, accelerator, attn_mode, None, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype + args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype ) assert clip_l is not None, "clip_l is required / clip_lは必須です" assert clip_g is not None, "clip_g is required / clip_gは必須です" @@ -769,10 +769,10 @@ def train(args): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, ) @@ -807,10 +807,10 @@ def train(args): epoch, num_train_epochs, global_step, - clip_l if args.save_clip else None, - clip_g if args.save_clip else None, - t5xxl if args.save_t5xxl else None, - mmdit, + accelerator.unwrap_model(clip_l) if args.save_clip else None, + accelerator.unwrap_model(clip_g) if args.save_clip else None, + accelerator.unwrap_model(t5xxl) if args.save_t5xxl else None, + accelerator.unwrap_model(mmdit), vae, )