Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge pull request #2060 from saibit-tech/sd3

Fix: try aligning dtype of matrices when training with DeepSpeed and mixed_precision is set to bf16 or fp16
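A minimal repro of the failure mode this change appears to address (the module and shapes are illustrative, not taken from the PR): a sub-model whose weights stay in fp32 receives bf16 activations, and the underlying matmul rejects the mixed dtypes.

import torch

proj = torch.nn.Linear(8, 8)                 # weights left in fp32 (e.g. a frozen sub-model)
h = torch.randn(2, 8, dtype=torch.bfloat16)  # bf16 activations from an upstream module

try:
    proj(h)
except RuntimeError as e:
    print(e)  # dtype-mismatch error: fp32 weights cannot be matmul'd with bf16 inputs directly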
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
 
 from .utils import setup_logging
+
+from .device_utils import get_preferred_device
 
 setup_logging()
 import logging
 
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
     deepspeed_plugin.deepspeed_config["train_batch_size"] = (
         args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
     )
 
     deepspeed_plugin.set_mixed_precision(args.mixed_precision)
     if args.mixed_precision.lower() == "fp16":
         deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
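The train_batch_size lines kept as context above encode DeepSpeed's requirement that the config's train_batch_size equal the per-device micro-batch times gradient accumulation steps times world size. A quick arithmetic sketch with made-up numbers (sd-scripts reads the real values from args and the launcher's WORLD_SIZE):

import os

# Illustrative values only; sd-scripts takes these from args and the launcher env.
train_batch_size = 4                                 # per-device micro-batch (args.train_batch_size)
gradient_accumulation_steps = 8
world_size = int(os.environ.get("WORLD_SIZE", "2"))  # set by the distributed launcher

# DeepSpeed validates: train_batch_size ==
#   micro_batch_per_gpu * gradient_accumulation_steps * world_size
print(train_batch_size * gradient_accumulation_steps * world_size)  # 64 with WORLD_SIZE=2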
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
     class DeepSpeedWrapper(torch.nn.Module):
         def __init__(self, **kw_models) -> None:
             super().__init__()
 
             self.models = torch.nn.ModuleDict()
 
+            wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"  # args comes from the enclosing prepare_deepspeed_model scope
+
             for key, model in kw_models.items():
                 if isinstance(model, list):
                     model = torch.nn.ModuleList(model)
 
+                if wrap_model_forward_with_torch_autocast:
+                    model = self.__wrap_model_with_torch_autocast(model)
+
                 assert isinstance(
                     model, torch.nn.Module
                 ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
 
                 self.models.update(torch.nn.ModuleDict({key: model}))
 
+        def __wrap_model_with_torch_autocast(self, model):
+            if isinstance(model, torch.nn.ModuleList):
+                model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
+            else:
+                model = self.__wrap_model_forward_with_torch_autocast(model)
+            return model
+
+        def __wrap_model_forward_with_torch_autocast(self, model):
+            assert hasattr(model, "forward"), "model must have a forward method."
+
+            forward_fn = model.forward
+
+            def forward(*args, **kwargs):
+                try:
+                    device_type = model.device.type
+                except AttributeError:
+                    logger.warning(
+                        "[DeepSpeed] model.device is not available. Using get_preferred_device() "
+                        "to determine the device_type for torch.autocast()."
+                    )
+                    device_type = get_preferred_device().type
+
+                with torch.autocast(device_type=device_type):
+                    return forward_fn(*args, **kwargs)
+
+            model.forward = forward
+            return model
+
         def get_models(self):
             return self.models
 
     ds_model = DeepSpeedWrapper(**models)
     return ds_model
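A standalone sketch of what __wrap_model_forward_with_torch_autocast does to each model. It is pinned to device_type="cpu" and bf16 so it runs anywhere; the real code derives the device type from model.device and falls back to get_preferred_device().

import torch

model = torch.nn.Linear(8, 8)  # weights in fp32
forward_fn = model.forward     # keep the original bound forward

def forward(*args, **kwargs):
    # fixed to cpu/bf16 for the sketch; the diff picks device_type at call time
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return forward_fn(*args, **kwargs)

model.forward = forward  # the same monkey-patch the wrapper applies

x = torch.randn(2, 8, dtype=torch.bfloat16)
print(model(x).dtype)  # torch.bfloat16: the matmul is autocast instead of erroring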
@@ -5498,6 +5498,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 
 
 def patch_accelerator_for_fp16_training(accelerator):
+    from accelerate import DistributedType
+
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        return  # DeepSpeed manages its own loss scaling; there is no scaler to patch
+
     org_unscale_grads = accelerator.scaler._unscale_grads_
 
     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
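The hunk cuts off at the replacer's def line. For context, a sketch of the full function after this change; the replacer body is how it reads upstream in sd-scripts and should be treated as an assumption here, since the diff does not show it:

def patch_accelerator_for_fp16_training(accelerator):
    from accelerate import DistributedType

    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        return  # DeepSpeed scales losses itself; accelerator.scaler is not set up

    org_unscale_grads = accelerator.scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        # force allow_fp16=True so fp16 gradients can be unscaled in place
        return org_unscale_grads(optimizer, inv_scale, found_inf, True)

    accelerator.scaler._unscale_grads_ = _unscale_grads_replacer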