diff --git a/library/train_util.py b/library/train_util.py index 3b74ddce..21e0d191 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -33,6 +33,7 @@ from tqdm import tqdm import torch from library.device_utils import init_ipex, clean_memory + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -76,6 +77,8 @@ from library.original_unet import UNet2DConditionModel TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +HIGH_VRAM = False + # checkpointファイル名 EPOCH_STATE_NAME = "{}-{:06d}-state" EPOCH_FILE_NAME = "{}-{:06d}" @@ -2281,8 +2284,8 @@ def cache_batch_latents( if flip_aug: info.latents_flipped = flipped_latent - # FIXME this slows down caching a lot, specify this as an option - clean_memory() + if not HIGH_VRAM: + clean_memory() def cache_batch_text_encoder_outputs( @@ -3037,7 +3040,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--lowram", action="store_true", - help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)", + ) + parser.add_argument( + "--highvram", + action="store_true", + help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) / VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュを行わない等(VRAMが多い環境向け)", ) parser.add_argument( @@ -3128,6 +3136,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: def verify_training_args(args: argparse.Namespace): + r""" + Verify training arguments. Also reflect highvram option to global variable + 学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する + """ + if args.highvram: + print("highvram is enabled / highvramが有効です") + global HIGH_VRAM + HIGH_VRAM = True + if args.v_parameterization and not args.v2: print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: