From 31069e1dc54c09dfedca15a27d22419a92f1ed8f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 30 Mar 2023 21:44:40 +0900 Subject: [PATCH] add comments about device for clarity --- library/train_util.py | 1 + train_network.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 35c49bff..59dbc44c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2648,6 +2648,7 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'): print("load StableDiffusion checkpoint") text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device) else: + # Diffusers model is loaded to CPU print("load Diffusers pretrained models") try: pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None) diff --git a/train_network.py b/train_network.py index c34a2e5a..200d8d84 100644 --- a/train_network.py +++ b/train_network.py @@ -128,6 +128,7 @@ def train(args): # モデルを読み込む for pi in range(accelerator.state.num_processes): + # TODO: modify other training scripts as well if pi == accelerator.state.local_process_index: print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device) @@ -136,6 +137,7 @@ def train(args): accelerator.wait_for_everyone() # work on low-ram device + # NOTE: this may not be necessary because we already load them on gpu if args.lowram: text_encoder.to(accelerator.device) unet.to(accelerator.device)