From 7c61c0dfe0e879fd6b66ccb70273e4b99deaf1c5 Mon Sep 17 00:00:00 2001
From: saibit
Date: Tue, 22 Apr 2025 16:06:55 +0800
Subject: [PATCH 1/8] Add autocast wrapper for forward functions in
 deepspeed_utils.py to align precision when using mixed precision during
 training

---
 library/deepspeed_utils.py | 32 ++++++++++++++++++++++++++++++++
 library/flux_models.py     |  2 +-
 library/train_util.py      |  5 +++++
 requirements.txt           |  3 ++-
 4 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py
index 99a7b2b3..3018def7 100644
--- a/library/deepspeed_utils.py
+++ b/library/deepspeed_utils.py
@@ -94,6 +94,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
         deepspeed_plugin.deepspeed_config["train_batch_size"] = (
             args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
         )
+        deepspeed_plugin.set_mixed_precision(args.mixed_precision)
 
     if args.mixed_precision.lower() == "fp16":
         deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
@@ -122,18 +123,49 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
     class DeepSpeedWrapper(torch.nn.Module):
         def __init__(self, **kw_models) -> None:
             super().__init__()
+
             self.models = torch.nn.ModuleDict()
+
+            warp_model_forward_with_torch_autocast = args.mixed_precision != "no"
 
             for key, model in kw_models.items():
                 if isinstance(model, list):
                     model = torch.nn.ModuleList(model)
+
                 assert isinstance(
                     model, torch.nn.Module
                 ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
+
+                if warp_model_forward_with_torch_autocast:
+                    model = self.__warp_with_torch_autocast(model)
+
                 self.models.update(torch.nn.ModuleDict({key: model}))
 
+        def __warp_with_torch_autocast(self, model):
+            if isinstance(model, torch.nn.ModuleList):
+                for i in range(len(model)):
+                    model[i] = self.__warp_model_forward_with_torch_autocast(model[i])
+            else:
+                model = self.__warp_model_forward_with_torch_autocast(model)
+            return model
+
+        def __warp_model_forward_with_torch_autocast(self, model):
+
+            assert hasattr(model, "forward"), f"model must have a forward method."
+
+            forward_fn = model.forward
+
+            def forward(*args, **kwargs):
+                device_type= "cuda" if torch.cuda.is_available() else "cpu"
+                with torch.autocast(device_type=device_type):
+                    return forward_fn(*args, **kwargs)
+            model.forward = forward
+
+            return model
+
         def get_models(self):
             return self.models
 
+
     ds_model = DeepSpeedWrapper(**models)
     return ds_model

diff --git a/library/flux_models.py b/library/flux_models.py
index 328ad481..12151ee8 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -1005,7 +1005,7 @@ class Flux(nn.Module):
             return
         self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
         self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
-
+
     def forward(
         self,
         img: Tensor,
diff --git a/library/train_util.py b/library/train_util.py
index 6c39f8d9..dbbfda3e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5495,6 +5495,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 
 
 def patch_accelerator_for_fp16_training(accelerator):
+
+    from accelerate import DistributedType
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        return
+
     org_unscale_grads = accelerator.scaler._unscale_grads_
 
     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
diff --git a/requirements.txt b/requirements.txt
index 767d9e8e..bead3f90 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 accelerate==0.33.0
 transformers==4.44.0
-diffusers[torch]==0.25.0
+diffusers==0.25.0
+deepspeed==0.16.7
 ftfy==6.1.1
 # albumentations==1.3.0
 opencv-python==4.8.1.78

From d33d5eccd16970e489359ee02b89a6259559e4b9 Mon Sep 17 00:00:00 2001
From: saibit
Date: Tue, 22 Apr 2025 16:12:06 +0800
Subject: [PATCH 2/8] #

---
 library/flux_models.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/library/flux_models.py b/library/flux_models.py
index 12151ee8..d7840d51 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -1004,8 +1004,7 @@ class Flux(nn.Module):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
         self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
-
+        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
     def forward(
         self,
         img: Tensor,

From 7f984f47758f9e17f4a82b92cb9dbc97b3ba982f Mon Sep 17 00:00:00 2001
From: saibit
Date: Tue, 22 Apr 2025 16:15:12 +0800
Subject: [PATCH 3/8] #

---
 library/flux_models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/library/flux_models.py b/library/flux_models.py
index d7840d51..328ad481 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -1004,7 +1004,8 @@ class Flux(nn.Module):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
         self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
-        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
+
     def forward(
         self,
         img: Tensor,

From c8af252a44a7dbc54a0c1622946faedef4e7c52b Mon Sep 17 00:00:00 2001
From: Robert
Date: Tue, 22 Apr 2025 16:19:14 +0800
Subject: [PATCH 4/8] refactor

---
 library/deepspeed_utils.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py
index 3018def7..f6eac367 100644
--- a/library/deepspeed_utils.py
+++ b/library/deepspeed_utils.py
@@ -126,7 +126,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
 
             self.models = torch.nn.ModuleDict()
 
-            warp_model_forward_with_torch_autocast = args.mixed_precision != "no"
+            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
 
             for key, model in kw_models.items():
                 if isinstance(model, list):
@@ -136,31 +136,30 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
                     model, torch.nn.Module
                 ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
 
-                if warp_model_forward_with_torch_autocast:
-                    model = self.__warp_with_torch_autocast(model)
+                if wrap_model_forward_with_torch_autocast:
+                    model = self.__wrap_model_with_torch_autocast(model)
 
                 self.models.update(torch.nn.ModuleDict({key: model}))
 
-        def __warp_with_torch_autocast(self, model):
+        def __wrap_model_with_torch_autocast(self, model):
             if isinstance(model, torch.nn.ModuleList):
-                for i in range(len(model)):
-                    model[i] = self.__warp_model_forward_with_torch_autocast(model[i])
+                model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model]
             else:
-                model = self.__warp_model_forward_with_torch_autocast(model)
+                model = self.__wrap_model_forward_with_torch_autocast(model)
             return model
 
-        def __warp_model_forward_with_torch_autocast(self, model):
+        def __wrap_model_forward_with_torch_autocast(self, model):
 
             assert hasattr(model, "forward"), f"model must have a forward method."
 
             forward_fn = model.forward
 
             def forward(*args, **kwargs):
-                device_type= "cuda" if torch.cuda.is_available() else "cpu"
+                device_type = "cuda" if torch.cuda.is_available() else "cpu"
                 with torch.autocast(device_type=device_type):
                     return forward_fn(*args, **kwargs)
+
             model.forward = forward
-
             return model
 
         def get_models(self):

From adb775c6165d93a856e33d0d9058efd629cf2a2d Mon Sep 17 00:00:00 2001
From: saibit
Date: Wed, 23 Apr 2025 17:05:20 +0800
Subject: [PATCH 5/8] Update: requirement diffusers[torch]==0.25.0

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index bead3f90..9e97eed3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 accelerate==0.33.0
 transformers==4.44.0
-diffusers==0.25.0
+diffusers[torch]==0.25.0
 deepspeed==0.16.7
 ftfy==6.1.1
 # albumentations==1.3.0

From abf2c44bc5650afef8bebbb1ef278c66f44c4dda Mon Sep 17 00:00:00 2001
From: sharlynxy
Date: Wed, 23 Apr 2025 18:57:19 +0800
Subject: [PATCH 6/8] Dynamically set device in deepspeed wrapper (#2)

* get device type from model
* add logger warning
* format
* format
* format
---
 library/deepspeed_utils.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py
index f6eac367..09c6f7b9 100644
--- a/library/deepspeed_utils.py
+++ b/library/deepspeed_utils.py
@@ -5,6 +5,8 @@
 from accelerate import DeepSpeedPlugin, Accelerator
 
 from .utils import setup_logging
+
+from .device_utils import get_preferred_device
 
 setup_logging()
 import logging
@@ -153,13 +155,21 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
             assert hasattr(model, "forward"), f"model must have a forward method."
 
             forward_fn = model.forward
-
+
             def forward(*args, **kwargs):
-                device_type = "cuda" if torch.cuda.is_available() else "cpu"
-                with torch.autocast(device_type=device_type):
+                try:
+                    device_type = model.device.type
+                except AttributeError:
+                    logger.warning(
+                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
+                        "to determine the device_type for torch.autocast()."
+                    )
+                    device_type = get_preferred_device().type
+
+                with torch.autocast(device_type = device_type):
                     return forward_fn(*args, **kwargs)
 
-            model.forward = forward
+            model.forward = forward
             return model
 
         def get_models(self):

From 46ad3be0593df1df9d485c3ac2efb5aebd87730c Mon Sep 17 00:00:00 2001
From: saibit
Date: Thu, 24 Apr 2025 11:26:36 +0800
Subject: [PATCH 7/8] update deepspeed wrapper

---
 library/deepspeed_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/library/deepspeed_utils.py b/library/deepspeed_utils.py
index 09c6f7b9..a8a05c3a 100644
--- a/library/deepspeed_utils.py
+++ b/library/deepspeed_utils.py
@@ -134,18 +134,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
                 if isinstance(model, list):
                     model = torch.nn.ModuleList(model)
 
+                if wrap_model_forward_with_torch_autocast:
+                    model = self.__wrap_model_with_torch_autocast(model)
+
                 assert isinstance(
                     model, torch.nn.Module
                 ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
 
-                if wrap_model_forward_with_torch_autocast:
-                    model = self.__wrap_model_with_torch_autocast(model)
-
                 self.models.update(torch.nn.ModuleDict({key: model}))
 
         def __wrap_model_with_torch_autocast(self, model):
             if isinstance(model, torch.nn.ModuleList):
-                model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model]
+                model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
             else:
                 model = self.__wrap_model_forward_with_torch_autocast(model)
             return model

From 1684ababcd7fc4259c77f1471ef41d10e612a721 Mon Sep 17 00:00:00 2001
From: sharlynxy
Date: Wed, 30 Apr 2025 19:51:09 +0800
Subject: [PATCH 8/8] remove deepspeed from requirements.txt

---
 requirements.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 9e97eed3..767d9e8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
 accelerate==0.33.0
 transformers==4.44.0
 diffusers[torch]==0.25.0
-deepspeed==0.16.7
 ftfy==6.1.1
 # albumentations==1.3.0
 opencv-python==4.8.1.78
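
Taken together, the series monkey-patches each wrapped module's forward so that it always executes inside torch.autocast, taking the device type from the model itself and falling back to a preferred device when the model does not expose one. A minimal standalone sketch of that technique follows; the function name and the getattr-based fallback are illustrative, not the patched library's API:

import torch

def wrap_forward_with_autocast(model: torch.nn.Module) -> torch.nn.Module:
    # Keep a reference to the original bound forward.
    forward_fn = model.forward

    def forward(*args, **kwargs):
        # Prefer the device the model itself reports; plain nn.Modules do not
        # expose .device, so fall back on CUDA availability here (patch 6
        # falls back to the library's get_preferred_device() instead).
        device = getattr(model, "device", None)
        device_type = device.type if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        with torch.autocast(device_type=device_type):
            return forward_fn(*args, **kwargs)

    # The instance attribute shadows the class method, so Module.__call__
    # now dispatches into the autocast-wrapped closure.
    model.forward = forward
    return model

# Usage: call sites are unchanged; the wrapped layer runs under autocast.
layer = wrap_forward_with_autocast(torch.nn.Linear(8, 8))
out = layer(torch.randn(2, 8))

Wrapping forward per module, rather than opening one autocast region around the training loop, lets the DeepSpeedWrapper apply the same policy to every entry of its ModuleDict, including models passed in as lists.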