Refactor caching in train scripts

This commit is contained in:
kohya-ss
2024-10-12 20:18:41 +09:00
parent ff4083b910
commit c80c304779
14 changed files with 95 additions and 47 deletions

View File

@@ -11,6 +11,16 @@ The command to install PyTorch is as follows:
### Recent Updates ### Recent Updates
Oct 12, 2024 (update 1):
- During multi-GPU training, caching of latents and Text Encoder outputs is now performed across multiple GPUs.
- `--text_encoder_batch_size` option is enabled for FLUX.1 LoRA training and fine tuning. This option specifies the batch size for caching Text Encoder outputs (not for training). The default is the same as the dataset batch size. If you have enough VRAM, you can increase the batch size to speed up the caching.
- `--skip_cache_check` option is added to each training script.
- When specified, the consistency check of the cache file `*.npz` contents (e.g., image size and flip for latents, mask for Text Encoder outputs) is skipped.
- Specify this option if you have a large number of cache files and the consistency check takes time.
- Even if this option is specified, the cache will be created if the file does not exist.
- `--skip_latents_validity_check` in SD3/FLUX.1 is deprecated. Please use `--skip_cache_check` instead.
Oct 12, 2024: Oct 12, 2024:
- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)! - Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!

View File

@@ -59,7 +59,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents: if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args) deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True) setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert ( # assert (
# not args.weighted_captions # not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -81,7 +85,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents: if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy( latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -142,7 +146,7 @@ def train(args):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy( strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy( strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
) )
) )
t5xxl_max_token_length = ( t5xxl_max_token_length = (
@@ -181,7 +185,7 @@ def train(args):
# load VAE for caching latents # load VAE for caching latents
ae = None ae = None
if cache_latents: if cache_latents:
ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors) ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype) ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False) ae.requires_grad_(False)
ae.eval() ae.eval()
@@ -229,7 +233,7 @@ def train(args):
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast(): with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator.is_main_process) train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory # cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None: if args.sample_prompts is not None:
@@ -952,7 +956,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--skip_latents_validity_check", "--skip_latents_validity_check",
action="store_true", action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
) )
parser.add_argument( parser.add_argument(
"--blocks_to_swap", "--blocks_to_swap",

View File

@@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if the text encoders is trained, we need tokenization, so is_partial is True # if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy( return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.cache_text_encoder_outputs_to_disk,
None, args.text_encoder_batch_size,
False, args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl, is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask, apply_t5_attn_mask=args.apply_t5_attn_mask,
) )
@@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[1].to(weight_dtype) text_encoders[1].to(weight_dtype)
with accelerator.autocast(): with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts # cache sample prompts
if args.sample_prompts is not None: if args.sample_prompts is not None:

View File

@@ -31,6 +31,7 @@ import hashlib
import subprocess import subprocess
from io import BytesIO from io import BytesIO
import toml import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed # from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm from tqdm import tqdm
@@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset):
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
r""" r"""
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy. a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
""" """
@@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset):
# split by resolution # split by resolution
batches = [] batches = []
batch = [] batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
# check disk cache exists and size of latents # support multiple-gpus
num_processes = accelerator.num_processes
process_index = accelerator.process_index
logger.info("checking cache validity...")
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk: if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available or not is_main_process: # do not add to batch if cache_available: # do not add to batch
continue continue
batch.append(info) batch.append(info)
@@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]") logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, accelerator) dataset.new_cache_latents(model, accelerator)
accelerator.wait_for_everyone()
def cache_text_encoder_outputs( def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
) )
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool): def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]") logger.info(f"[Dataset {i}]")
dataset.new_cache_text_encoder_outputs(models, is_main_process) dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def set_caching_mode(self, caching_mode): def set_caching_mode(self, caching_mode):
for dataset in self.datasets: for dataset in self.datasets:
@@ -4210,6 +4223,12 @@ def add_dataset_arguments(
action="store_true", action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可", help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
) )
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument( parser.add_argument(
"--enable_bucket", "--enable_bucket",
action="store_true", action="store_true",
@@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend = args.dynamo_backend dynamo_backend = args.dynamo_backend
kwargs_handlers = [ kwargs_handlers = [
InitProcessGroupKwargs( (
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", InitProcessGroupKwargs(
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None init_method=(
) if torch.cuda.device_count() > 1 else None, "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
DistributedDataParallelKwargs( ),
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
static_graph=args.ddp_static_graph )
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None if torch.cuda.device_count() > 1
else None
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
] ]
kwargs_handlers = [i for i in kwargs_handlers if i is not None] kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

View File

@@ -57,6 +57,10 @@ def train(args):
deepspeed_utils.prepare_deepspeed_args(args) deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True) setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
assert ( assert (
not args.weighted_captions not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -103,7 +107,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents: if args.cache_latents:
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy( latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -312,7 +316,7 @@ def train(args):
text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy( text_encoder_caching_strategy = strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size, args.text_encoder_batch_size,
False, args.skip_cache_check,
train_clip_g or train_clip_l or args.use_t5xxl_cache_only, train_clip_g or train_clip_l or args.use_t5xxl_cache_only,
args.apply_lg_attn_mask, args.apply_lg_attn_mask,
args.apply_t5_attn_mask, args.apply_t5_attn_mask,
@@ -325,7 +329,7 @@ def train(args):
t5xxl.to(t5xxl_device, dtype=t5xxl_dtype) t5xxl.to(t5xxl_device, dtype=t5xxl_dtype)
with accelerator.autocast(): with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator.is_main_process) train_dataset_group.new_cache_text_encoder_outputs([clip_l, clip_g, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory # cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None: if args.sample_prompts is not None:
@@ -1052,7 +1056,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--skip_latents_validity_check", "--skip_latents_validity_check",
action="store_true", action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする", help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip cache (latents and text encoder outputs) check / キャッシュlatentsとtext encoder outputsのチェックをスキップする",
) )
parser.add_argument( parser.add_argument(
"--num_last_block_to_freeze", "--num_last_block_to_freeze",

View File

@@ -131,7 +131,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents: if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -328,7 +328,7 @@ def train(args):
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device) text_encoder2.to(accelerator.device)
with accelerator.autocast(): with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -84,7 +84,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -230,7 +230,7 @@ def train(args):
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device) text_encoder2.to(accelerator.device)
with accelerator.autocast(): with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()

View File

@@ -93,7 +93,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -202,7 +202,7 @@ def train(args):
text_encoder1.to(accelerator.device) text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device) text_encoder2.to(accelerator.device)
with accelerator.autocast(): with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process) train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -431,7 +431,6 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents) latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached # Text Encoder outputs are cached

View File

@@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_latents_caching_strategy(self, args): def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
return latents_caching_strategy return latents_caching_strategy
@@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args): def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
) )
else: else:
return None return None
@@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype) text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast(): with accelerator.autocast():
dataset.new_cache_text_encoder_outputs( dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU

View File

@@ -49,7 +49,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def get_latents_caching_strategy(self, args): def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
return latents_caching_strategy return latents_caching_strategy

View File

@@ -64,7 +64,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

View File

@@ -116,7 +116,7 @@ class NetworkTrainer:
def get_latents_caching_strategy(self, args): def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
return latents_caching_strategy return latents_caching_strategy

View File

@@ -114,7 +114,7 @@ class TextualInversionTrainer:
def get_latents_caching_strategy(self, args): def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, False True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
) )
return latents_caching_strategy return latents_caching_strategy