mirror of https://github.com/kohya-ss/sd-scripts.git

Refactor caching in train scripts
@@ -31,6 +31,7 @@ import hashlib
 import subprocess
 from io import BytesIO
 import toml

+# from concurrent.futures import ThreadPoolExecutor, as_completed

 from tqdm import tqdm
@@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset):
         for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
             cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)

-    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+    def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
         r"""
         a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
         """
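Note: the signature change above swaps a bare `is_main_process` flag for the `Accelerator` object, which carries the `num_processes` and `process_index` values used by the sharding in the next hunk. A minimal sketch of a call site under that assumption (`StubDataset` is illustrative, not repository code):

    from typing import Any, List
    from accelerate import Accelerator

    accelerator = Accelerator()

    class StubDataset:
        # Stand-in with the refactored signature.
        def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator) -> None:
            print(f"caching on process {accelerator.process_index} of {accelerator.num_processes}")

    StubDataset().new_cache_text_encoder_outputs([], accelerator)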
@@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset):
         # split by resolution
         batches = []
         batch = []
-        logger.info("checking cache validity...")
-        for info in tqdm(image_infos):
-            te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)

-            # check disk cache exists and size of latents
+        # support multiple-gpus
+        num_processes = accelerator.num_processes
+        process_index = accelerator.process_index
+
+        logger.info("checking cache validity...")
+        for i, info in enumerate(tqdm(image_infos)):
+            # check disk cache exists and size of text encoder outputs
             if caching_strategy.cache_to_disk:
-                info.text_encoder_outputs_npz = te_out_npz  # set npz filename regardless of cache availability/main process
+                te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
+                info.text_encoder_outputs_npz = te_out_npz  # set npz filename regardless of cache availability
+
+                # if the modulo of num_processes is not equal to process_index, skip caching
+                # this makes each process cache different text encoder outputs
+                if i % num_processes != process_index:
+                    continue
+
                 cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
-                if cache_available or not is_main_process:  # do not add to batch
+                if cache_available:  # do not add to batch
                     continue

             batch.append(info)
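Note: the heart of this hunk is modulo sharding: every process now caches the items where `i % num_processes == process_index`, instead of the main process doing all the work. A self-contained sketch of the same pattern (`shard` and the file names are illustrative):

    from typing import Iterable, List

    def shard(items: Iterable[str], num_processes: int, process_index: int) -> List[str]:
        # Item i is owned by exactly one process, so no file is cached twice
        # and the work splits roughly evenly across GPUs.
        return [item for i, item in enumerate(items) if i % num_processes == process_index]

    images = [f"img_{i}.png" for i in range(8)]
    assert shard(images, 2, 0) == ["img_0.png", "img_2.png", "img_4.png", "img_6.png"]
    assert shard(images, 2, 1) == ["img_1.png", "img_3.png", "img_5.png", "img_7.png"]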
@@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
             dataset.new_cache_latents(model, accelerator)
+        accelerator.wait_for_everyone()

     def cache_text_encoder_outputs(
         self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
                 tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
             )

-    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
+    def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
         for i, dataset in enumerate(self.datasets):
             logger.info(f"[Dataset {i}]")
-            dataset.new_cache_text_encoder_outputs(models, is_main_process)
+            dataset.new_cache_text_encoder_outputs(models, accelerator)
+        accelerator.wait_for_everyone()

     def set_caching_mode(self, caching_mode):
         for dataset in self.datasets:
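Note: the added `accelerator.wait_for_everyone()` calls are the synchronization that the sharding relies on: each process writes a different slice of the cache, so all processes must reach the barrier before any of them moves on to read caches it did not write. A minimal sketch (the `Accelerator` calls are the real API; `cache_shard` is illustrative):

    from accelerate import Accelerator

    accelerator = Accelerator()  # in a single-process run the barrier is a no-op

    def cache_shard(name: str) -> None:
        print(f"process {accelerator.process_index} caching its slice of {name}")

    for name in ["dataset_0", "dataset_1"]:
        cache_shard(name)
    # No process continues until every process has finished writing its shards.
    accelerator.wait_for_everyone()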
@@ -4210,6 +4223,12 @@ def add_dataset_arguments(
         action="store_true",
         help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
     )
+    parser.add_argument(
+        "--skip_cache_check",
+        action="store_true",
+        help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
+        " / cacheの内容の検証をスキップする(latentとテキストエンコーダの出力)。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
+    )
     parser.add_argument(
         "--enable_bucket",
         action="store_true",
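Note: per the help text, `--skip_cache_check` skips only the content validation; the existence check always runs, so missing files are still cached. A hedged sketch of how a caching strategy could honor such a flag (the function body and the `"hidden_state"` key are illustrative, not the repository's actual validation):

    import os
    import numpy as np

    def is_disk_cached_outputs_expected(npz_path: str, skip_cache_check: bool) -> bool:
        if not os.path.exists(npz_path):
            return False  # missing file: must be cached regardless of the flag
        if skip_cache_check:
            return True  # trust the file without opening it
        try:
            with np.load(npz_path) as npz:
                return "hidden_state" in npz  # illustrative key check
        except Exception:
            return False  # unreadable/corrupt file: re-cache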
@@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace):
     dynamo_backend = args.dynamo_backend

     kwargs_handlers = [
-        InitProcessGroupKwargs(
-            backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-            init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
-            timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
-        ) if torch.cuda.device_count() > 1 else None,
-        DistributedDataParallelKwargs(
-            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
-            static_graph=args.ddp_static_graph
-        ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
+        (
+            InitProcessGroupKwargs(
+                backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+                init_method=(
+                    "env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
+                ),
+                timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
+            )
+            if torch.cuda.device_count() > 1
+            else None
+        ),
+        (
+            DistributedDataParallelKwargs(
+                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
+            )
+            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
+            else None
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
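Note: this last hunk is a behavior-preserving reformat; both versions build a list of optional handlers and strip the `None` entries before handing them to the `Accelerator`. A minimal sketch of that filter-then-pass pattern (conditions and values are placeholders):

    import os
    import torch
    from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

    multi_gpu = torch.cuda.device_count() > 1
    use_ddp_tweaks = True  # stands in for args.ddp_gradient_as_bucket_view or args.ddp_static_graph

    kwargs_handlers = [
        InitProcessGroupKwargs(backend="gloo" if os.name == "nt" else "nccl") if multi_gpu else None,
        DistributedDataParallelKwargs(static_graph=True) if use_ddp_tweaks else None,
    ]
    kwargs_handlers = [h for h in kwargs_handlers if h is not None]

    accelerator = Accelerator(kwargs_handlers=kwargs_handlers)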