mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
add highvram option and do not clear cache in caching latents
This commit is contained in:
@@ -33,6 +33,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -76,6 +77,8 @@ from library.original_unet import UNet2DConditionModel
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
|
||||
HIGH_VRAM = False
|
||||
|
||||
# checkpointファイル名
|
||||
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
||||
EPOCH_FILE_NAME = "{}-{:06d}"
|
||||
@@ -2281,8 +2284,8 @@ def cache_batch_latents(
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
|
||||
# FIXME this slows down caching a lot, specify this as an option
|
||||
clean_memory()
|
||||
if not HIGH_VRAM:
|
||||
clean_memory()
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs(
|
||||
@@ -3037,7 +3040,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--lowram",
|
||||
action="store_true",
|
||||
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなど(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)",
|
||||
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込む等(ColabやKaggleなどRAMに比べてVRAMが多い環境向け)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highvram",
|
||||
action="store_true",
|
||||
help="disable low VRAM optimization. e.g. do not clear CUDA cache after each latent caching (for machines which have bigger VRAM) / VRAMが少ない環境向け最適化を無効にする。たとえば各latentのキャッシュ後のCUDAキャッシュを行わない等(VRAMが多い環境向け)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -3128,6 +3136,15 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
r"""
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
学習用引数を検証する。あわせて highvram オプションの指定をグローバル変数に反映する
|
||||
"""
|
||||
if args.highvram:
|
||||
print("highvram is enabled / highvramが有効です")
|
||||
global HIGH_VRAM
|
||||
HIGH_VRAM = True
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
print("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
|
||||
if args.v2 and args.clip_skip is not None:
|
||||
|
||||
Reference in New Issue
Block a user