Merge pull request #332 from guaneec/ddp-lowram

Reduce peak RAM usage
This commit is contained in:
Kohya S
2023-03-30 21:37:37 +09:00
committed by GitHub
3 changed files with 19 additions and 17 deletions

View File

@@ -2640,13 +2640,13 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype):
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
else:
print("load Diffusers pretrained models")
try: