add highvram option and do not clear cache in caching latents

This commit is contained in:
Kohya S
2024-02-01 21:55:55 +09:00
parent 9f0f0d573d
commit 5cca1fdc40

View File

@@ -33,6 +33,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -76,6 +77,8 @@ from library.original_unet import UNet2DConditionModel
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
HIGH_VRAM = False
# checkpointファイル名
EPOCH_STATE_NAME = "{}-{:06d}-state"
EPOCH_FILE_NAME = "{}-{:06d}"
@@ -2281,8 +2284,8 @@ def cache_batch_latents(
if flip_aug:
info.latents_flipped = flipped_latent
# FIXME this slows down caching a lot, specify this as an option
clean_memory()
if not HIGH_VRAM:
clean_memory()
def cache_batch_text_encoder_outputs(
@@ -3037,7 +3040,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--lowram",
action="store_true",
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなどColabやKaggleなどRAMに比べてVRAMが多い環境向け",
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むColabやKaggleなどRAMに比べてVRAMが多い環境向け",
)
parser.add_argument(
"--highvram",
action="store_true",
help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) / VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュを行わない等VRAMが多い環境向け",
)
parser.add_argument(
@@ -3128,6 +3136,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
"""
if args.highvram:
print("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None: