mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support multi gpu in caching text encoder outputs
This commit is contained in:
@@ -25,7 +25,10 @@ The feature of SDXL training is now available in sdxl branch as an experimental
|
||||
Summary of the feature:
|
||||
|
||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
|
||||
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
||||
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||
- Both scripts has following additional options:
|
||||
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||
@@ -64,6 +67,7 @@ learning_rate = 4e-7 # SDXL original learning rate
|
||||
- [ ] Support Textual Inversion training.
|
||||
- [ ] Support `--weighted_captions` option.
|
||||
- [ ] Change `--output_config` option to continue the training.
|
||||
- [ ] Extend `--full_bf16` for all the scripts.
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
|
||||
@@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
|
||||
print("caching text encoder outputs")
|
||||
|
||||
tokenizer1, tokenizer2 = tokenizers
|
||||
@@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
|
||||
|
||||
text_encoder1_cache = {}
|
||||
text_encoder2_cache = {}
|
||||
for batch in tqdm(data_loader):
|
||||
input_ids1_batch = batch["input_ids"]
|
||||
input_ids2_batch = batch["input_ids2"]
|
||||
for batch in tqdm(dataset):
|
||||
input_ids1_batch = batch["input_ids"].to(accelerator.device)
|
||||
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
|
||||
|
||||
# split batch to avoid OOM
|
||||
# TODO specify batch size by args
|
||||
|
||||
@@ -204,12 +204,25 @@ def train(args):
|
||||
text_encoder2.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder1)
|
||||
training_models.append(text_encoder2)
|
||||
|
||||
text_encoder1_cache = None
|
||||
text_encoder2_cache = None
|
||||
|
||||
# set require_grad=True later
|
||||
else:
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder2.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
@@ -289,23 +302,16 @@ def train(args):
|
||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
# Text Encoder doesn't work on CPU with fp16
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
text_encoder1_cache = None
|
||||
text_encoder2_cache = None
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
return args.cache_text_encoder_outputs
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
@@ -61,7 +61,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
|
||||
@@ -255,6 +255,11 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
@@ -419,11 +424,6 @@ class NetworkTrainer:
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype
|
||||
)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
Reference in New Issue
Block a user