mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 00:17:18 +00:00
refactor
This commit is contained in:
@@ -126,7 +126,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
warp_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
@@ -136,31 +136,30 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
if warp_model_forward_with_torch_autocast:
|
||||
model = self.__warp_with_torch_autocast(model)
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __warp_with_torch_autocast(self, model):
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
for i in range(len(model)):
|
||||
model[i] = self.__warp_model_forward_with_torch_autocast(model[i])
|
||||
model = [self.__wrap_model_forward_with_torch_autocast(m) for m in model]
|
||||
else:
|
||||
model = self.__warp_model_forward_with_torch_autocast(model)
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __warp_model_forward_with_torch_autocast(self, model):
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
device_type= "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
with torch.autocast(device_type=device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
|
||||
Reference in New Issue
Block a user