mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix DDP issues and Support DDP for all training scripts (#448)
* Fix DDP bugs * Fix DDP bugs for finetune and db * refactor model loader * fix DDP network * try to fix DDP network in train unet only * remove unuse DDP import * refactor DDP transform * refactor DDP transform * fix sample images bugs * change DDP tranform location * add autocast to train_db * support DDP in XTI * Clear DDP import
This commit is contained in:
@@ -98,7 +98,7 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# Convert the init_word to token_id
|
||||
if args.init_word is not None:
|
||||
@@ -280,6 +280,9 @@ def train(args):
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
Reference in New Issue
Block a user