add comments about debice for clarify

This commit is contained in:
Kohya S
2023-03-30 21:44:40 +09:00
parent 6c28dfb417
commit 31069e1dc5
2 changed files with 3 additions and 0 deletions

View File

@@ -128,6 +128,7 @@ def train(args):
# モデルを読み込む
for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
@@ -136,6 +137,7 @@ def train(args):
accelerator.wait_for_everyone()
# work on low-ram device
# NOTE: this may not be necessary because we already load them on gpu
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)