add comments about device for clarity

This commit is contained in:
Kohya S
2023-03-30 21:44:40 +09:00
parent 6c28dfb417
commit 31069e1dc5
2 changed files with 3 additions and 0 deletions

View File

@@ -2648,6 +2648,7 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
else:
# Diffusers model is loaded to CPU
print("load Diffusers pretrained models")
try:
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)

View File

@@ -128,6 +128,7 @@ def train(args):
# モデルを読み込む
for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
@@ -136,6 +137,7 @@ def train(args):
accelerator.wait_for_everyone()
# work on low-ram device
# NOTE: this may not be necessary because we already load them on gpu
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)