mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
fix: str is not "no" to str != "no"
This commit is contained in:
@@ -128,7 +128,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
|
|
||||||
self.models = torch.nn.ModuleDict()
|
self.models = torch.nn.ModuleDict()
|
||||||
|
|
||||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
wrap_model_forward_with_torch_autocast = args.mixed_precision != "no"
|
||||||
|
|
||||||
for key, model in kw_models.items():
|
for key, model in kw_models.items():
|
||||||
if isinstance(model, list):
|
if isinstance(model, list):
|
||||||
@@ -166,7 +166,7 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
)
|
)
|
||||||
device_type = get_preferred_device().type
|
device_type = get_preferred_device().type
|
||||||
|
|
||||||
with torch.autocast(device_type = device_type):
|
with torch.autocast(device_type=device_type):
|
||||||
return forward_fn(*args, **kwargs)
|
return forward_fn(*args, **kwargs)
|
||||||
|
|
||||||
model.forward = forward
|
model.forward = forward
|
||||||
@@ -175,6 +175,5 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|||||||
def get_models(self):
|
def get_models(self):
|
||||||
return self.models
|
return self.models
|
||||||
|
|
||||||
|
|
||||||
ds_model = DeepSpeedWrapper(**models)
|
ds_model = DeepSpeedWrapper(**models)
|
||||||
return ds_model
|
return ds_model
|
||||||
|
|||||||
Reference in New Issue
Block a user