support multi gpu in caching text encoder outputs

This commit is contained in:
Kohya S
2023-07-09 16:02:56 +09:00
parent 3579b4570f
commit 0416f26a76
5 changed files with 32 additions and 22 deletions

View File

@@ -25,6 +25,9 @@ The feature of SDXL training is now available in sdxl branch as an experimental
Summary of the feature: Summary of the feature:
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset. - `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning. - `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`. - `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts have the following additional options: - Both scripts have the following additional options:
@@ -64,6 +67,7 @@ learning_rate = 4e-7 # SDXL original learning rate
- [ ] Support Textual Inversion training. - [ ] Support Textual Inversion training.
- [ ] Support `--weighted_captions` option. - [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training. - [ ] Change `--output_config` option to continue the training.
- [ ] Extend `--full_bf16` for all the scripts.
## About requirements.txt ## About requirements.txt

View File

@@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
# TextEncoderの出力をキャッシュする # TextEncoderの出力をキャッシュする
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる # weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype): def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
print("caching text encoder outputs") print("caching text encoder outputs")
tokenizer1, tokenizer2 = tokenizers tokenizer1, tokenizer2 = tokenizers
@@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
text_encoder1_cache = {} text_encoder1_cache = {}
text_encoder2_cache = {} text_encoder2_cache = {}
for batch in tqdm(data_loader): for batch in tqdm(dataset):
input_ids1_batch = batch["input_ids"] input_ids1_batch = batch["input_ids"].to(accelerator.device)
input_ids2_batch = batch["input_ids2"] input_ids2_batch = batch["input_ids2"].to(accelerator.device)
# split batch to avoid OOM # split batch to avoid OOM
# TODO specify batch size by args # TODO specify batch size by args

View File

@@ -204,12 +204,25 @@ def train(args):
text_encoder2.gradient_checkpointing_enable() text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1) training_models.append(text_encoder1)
training_models.append(text_encoder2) training_models.append(text_encoder2)
text_encoder1_cache = None
text_encoder2_cache = None
# set require_grad=True later
else: else:
text_encoder1.requires_grad_(False) text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False) text_encoder2.requires_grad_(False)
text_encoder1.eval() text_encoder1.eval()
text_encoder2.eval() text_encoder2.eval()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encoders are eval and no grad
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
)
accelerator.wait_for_everyone()
if not cache_latents: if not cache_latents:
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
@@ -289,23 +302,16 @@ def train(args):
(unet,) = train_util.transform_models_if_DDP([unet]) (unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype) text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype) text_encoder2.to(weight_dtype)
text_encoder1.eval()
text_encoder2.eval()
# TextEncoderの出力をキャッシュする # TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
)
accelerator.wait_for_everyone()
# Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32) text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
else: else:
text_encoder1_cache = None # make sure Text Encoders are on GPU
text_encoder2_cache = None
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device) text_encoder2.to(accelerator.device)

View File

@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return args.cache_text_encoder_outputs return args.cache_text_encoder_outputs
def cache_text_encoder_outputs_if_needed( def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
): ):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
if not args.lowram: if not args.lowram:
@@ -61,7 +61,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs( text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
) )
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU

View File

@@ -255,6 +255,11 @@ class NetworkTrainer:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
# prepare network # prepare network
net_kwargs = {} net_kwargs = {}
if args.network_args is not None: if args.network_args is not None:
@@ -419,11 +424,6 @@ class NetworkTrainer:
vae.eval() vae.eval()
vae.to(accelerator.device, dtype=vae_dtype) vae.to(accelerator.device, dtype=vae_dtype)
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype
)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする # 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16: if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator) train_util.patch_accelerator_for_fp16_training(accelerator)