Refactor caching in train scripts

This commit is contained in:
kohya-ss
2024-10-12 20:18:41 +09:00
parent ff4083b910
commit c80c304779
14 changed files with 95 additions and 47 deletions

View File

@@ -31,6 +31,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -1192,7 +1193,7 @@ class BaseDataset(torch.utils.data.Dataset):
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
r"""
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
"""
@@ -1207,15 +1208,25 @@ class BaseDataset(torch.utils.data.Dataset):
# split by resolution
batches = []
batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
# check disk cache exists and size of latents
# support multiple-gpus
num_processes = accelerator.num_processes
process_index = accelerator.process_index
logger.info("checking cache validity...")
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability/main process
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available or not is_main_process: # do not add to batch
if cache_available: # do not add to batch
continue
batch.append(info)
@@ -2420,6 +2431,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(model, accelerator)
accelerator.wait_for_everyone()
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2437,10 +2449,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_text_encoder_outputs(models, is_main_process)
dataset.new_cache_text_encoder_outputs(models, accelerator)
accelerator.wait_for_everyone()
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -4210,6 +4223,12 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--skip_cache_check",
action="store_true",
help="skip the content validation of cache (latent and text encoder output). Cache file existence check is always performed, and cache processing is performed if the file does not exist"
" / cacheの内容の検証をスキップするlatentとテキストエンコーダの出力。キャッシュファイルの存在確認は常に行われ、ファイルがなければキャッシュ処理が行われる",
)
parser.add_argument(
"--enable_bucket",
action="store_true",
@@ -5084,15 +5103,24 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend = args.dynamo_backend
kwargs_handlers = [
InitProcessGroupKwargs(
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
) if torch.cuda.device_count() > 1 else None,
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
static_graph=args.ddp_static_graph
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
(
InitProcessGroupKwargs(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method=(
"env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None
),
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None,
)
if torch.cuda.device_count() > 1
else None
),
(
DistributedDataParallelKwargs(
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
]
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)