From 0416f26a76c39911afe7aae09f2e628e02922a1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jul 2023 16:02:56 +0900 Subject: [PATCH] support multi gpu in caching text encoder outputs --- README.md | 6 +++++- library/sdxl_train_util.py | 8 ++++---- sdxl_train.py | 26 ++++++++++++++++---------- sdxl_train_network.py | 4 ++-- train_network.py | 10 +++++----- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 503427a5..8a47ee3b 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,10 @@ The feature of SDXL training is now available in sdxl branch as an experimental Summary of the feature: - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - - `prepare_buckets_latents.py` now supports SDXL fine-tuning. + - `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage. + - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer. + - I cannot find bitsandbytes>0.35.0 that works correctly on Windows. +- `prepare_buckets_latents.py` now supports SDXL fine-tuning. - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. - Both scripts has following additional options: - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions. @@ -64,6 +67,7 @@ learning_rate = 4e-7 # SDXL original learning rate - [ ] Support Textual Inversion training. - [ ] Support `--weighted_captions` option. - [ ] Change `--output_config` option to continue the training. +- [ ] Extend `--full_bf16` for all the scripts. 
## About requirements.txt diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index c67a7043..675aac3d 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise( # TextEncoderの出力をキャッシュする # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる -def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype): +def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype): print("caching text encoder outputs") tokenizer1, tokenizer2 = tokenizers @@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat text_encoder1_cache = {} text_encoder2_cache = {} - for batch in tqdm(data_loader): - input_ids1_batch = batch["input_ids"] - input_ids2_batch = batch["input_ids2"] + for batch in tqdm(dataset): + input_ids1_batch = batch["input_ids"].to(accelerator.device) + input_ids2_batch = batch["input_ids2"].to(accelerator.device) # split batch to avoid OOM # TODO specify batch size by args diff --git a/sdxl_train.py b/sdxl_train.py index 06cbc571..dd5b74dd 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -204,12 +204,25 @@ def train(args): text_encoder2.gradient_checkpointing_enable() training_models.append(text_encoder1) training_models.append(text_encoder2) + + text_encoder1_cache = None + text_encoder2_cache = None + + # set requires_grad=True later else: text_encoder1.requires_grad_(False) text_encoder2.requires_grad_(False) text_encoder1.eval() text_encoder2.eval() + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encoders are eval and no grad + text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( + args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None + ) + accelerator.wait_for_everyone() + if not cache_latents: vae.requires_grad_(False) vae.eval() @@ 
-289,23 +302,16 @@ def train(args): (unet,) = train_util.transform_models_if_DDP([unet]) text_encoder1.to(weight_dtype) text_encoder2.to(weight_dtype) - text_encoder1.eval() - text_encoder2.eval() - # TextEncoderの出力をキャッシュする + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: - text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( - args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None - ) - accelerator.wait_for_everyone() - # Text Encoder doesn't work on CPU with fp16 + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) if torch.cuda.is_available(): torch.cuda.empty_cache() else: - text_encoder1_cache = None - text_encoder2_cache = None + # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) text_encoder2.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index ec15ce4b..0c3c0cc5 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): return args.cache_text_encoder_outputs def cache_text_encoder_outputs_if_needed( - self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype + self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype ): if args.cache_text_encoder_outputs: if not args.lowram: @@ -61,7 +61,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): torch.cuda.empty_cache() text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( - args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype + args, accelerator, tokenizers, text_encoders, dataset, weight_dtype ) accelerator.wait_for_everyone() text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU diff --git 
a/train_network.py b/train_network.py index 3c9515b5..f7ee451b 100644 --- a/train_network.py +++ b/train_network.py @@ -255,6 +255,11 @@ class NetworkTrainer: accelerator.wait_for_everyone() + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される + self.cache_text_encoder_outputs_if_needed( + args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype + ) + # prepare network net_kwargs = {} if args.network_args is not None: @@ -419,11 +424,6 @@ class NetworkTrainer: vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される - self.cache_text_encoder_outputs_if_needed( - args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype - ) - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator)