Merge pull request #332 from guaneec/ddp-lowram

Reduce peak RAM usage
This commit is contained in:
Kohya S
2023-03-30 21:37:37 +09:00
committed by GitHub
3 changed files with 19 additions and 17 deletions

View File

@@ -127,12 +127,18 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)