mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support multi gpu in caching text encoder outputs
This commit is contained in:
@@ -25,7 +25,10 @@ The feature of SDXL training is now available in sdxl branch as an experimental
|
|||||||
Summary of the feature:
|
Summary of the feature:
|
||||||
|
|
||||||
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
|
||||||
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
- `--full_bf16` option is added. This option enables the full bfloat16 training. This option is useful to reduce the GPU memory usage.
|
||||||
|
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
|
||||||
|
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
|
||||||
|
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
|
||||||
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
|
||||||
- Both scripts have the following additional options:
|
- Both scripts have the following additional options:
|
||||||
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
- `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
|
||||||
@@ -64,6 +67,7 @@ learning_rate = 4e-7 # SDXL original learning rate
|
|||||||
- [ ] Support Textual Inversion training.
|
- [ ] Support Textual Inversion training.
|
||||||
- [ ] Support `--weighted_captions` option.
|
- [ ] Support `--weighted_captions` option.
|
||||||
- [ ] Change `--output_config` option to continue the training.
|
- [ ] Change `--output_config` option to continue the training.
|
||||||
|
- [ ] Extend `--full_bf16` for all the scripts.
|
||||||
|
|
||||||
## About requirements.txt
|
## About requirements.txt
|
||||||
|
|
||||||
|
|||||||
@@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
|||||||
|
|
||||||
# TextEncoderの出力をキャッシュする
|
# TextEncoderの出力をキャッシュする
|
||||||
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
|
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
|
||||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
|
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
|
||||||
print("caching text encoder outputs")
|
print("caching text encoder outputs")
|
||||||
|
|
||||||
tokenizer1, tokenizer2 = tokenizers
|
tokenizer1, tokenizer2 = tokenizers
|
||||||
@@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
|
|||||||
|
|
||||||
text_encoder1_cache = {}
|
text_encoder1_cache = {}
|
||||||
text_encoder2_cache = {}
|
text_encoder2_cache = {}
|
||||||
for batch in tqdm(data_loader):
|
for batch in tqdm(dataset):
|
||||||
input_ids1_batch = batch["input_ids"]
|
input_ids1_batch = batch["input_ids"].to(accelerator.device)
|
||||||
input_ids2_batch = batch["input_ids2"]
|
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
|
||||||
|
|
||||||
# split batch to avoid OOM
|
# split batch to avoid OOM
|
||||||
# TODO specify batch size by args
|
# TODO specify batch size by args
|
||||||
|
|||||||
@@ -204,12 +204,25 @@ def train(args):
|
|||||||
text_encoder2.gradient_checkpointing_enable()
|
text_encoder2.gradient_checkpointing_enable()
|
||||||
training_models.append(text_encoder1)
|
training_models.append(text_encoder1)
|
||||||
training_models.append(text_encoder2)
|
training_models.append(text_encoder2)
|
||||||
|
|
||||||
|
text_encoder1_cache = None
|
||||||
|
text_encoder2_cache = None
|
||||||
|
|
||||||
|
# set require_grad=True later
|
||||||
else:
|
else:
|
||||||
text_encoder1.requires_grad_(False)
|
text_encoder1.requires_grad_(False)
|
||||||
text_encoder2.requires_grad_(False)
|
text_encoder2.requires_grad_(False)
|
||||||
text_encoder1.eval()
|
text_encoder1.eval()
|
||||||
text_encoder2.eval()
|
text_encoder2.eval()
|
||||||
|
|
||||||
|
# TextEncoderの出力をキャッシュする
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
# Text Encoders are in eval mode and do not require grad
|
||||||
|
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||||
|
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
|
||||||
|
)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -289,23 +302,16 @@ def train(args):
|
|||||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||||
text_encoder1.to(weight_dtype)
|
text_encoder1.to(weight_dtype)
|
||||||
text_encoder2.to(weight_dtype)
|
text_encoder2.to(weight_dtype)
|
||||||
text_encoder1.eval()
|
|
||||||
text_encoder2.eval()
|
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュする
|
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
|
|
||||||
)
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
# Text Encoder doesn't work on CPU with fp16
|
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
else:
|
else:
|
||||||
text_encoder1_cache = None
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder2_cache = None
|
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
text_encoder2.to(accelerator.device)
|
text_encoder2.to(accelerator.device)
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return args.cache_text_encoder_outputs
|
return args.cache_text_encoder_outputs
|
||||||
|
|
||||||
def cache_text_encoder_outputs_if_needed(
|
def cache_text_encoder_outputs_if_needed(
|
||||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
|
||||||
):
|
):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
@@ -61,7 +61,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
|
||||||
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
|
args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
|
|||||||
@@ -255,6 +255,11 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||||
|
self.cache_text_encoder_outputs_if_needed(
|
||||||
|
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
# prepare network
|
# prepare network
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args is not None:
|
if args.network_args is not None:
|
||||||
@@ -419,11 +424,6 @@ class NetworkTrainer:
|
|||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
|
||||||
self.cache_text_encoder_outputs_if_needed(
|
|
||||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|||||||
Reference in New Issue
Block a user