skip npz check for multi node training

This commit is contained in:
Darren Lau
2025-03-05 14:30:18 +08:00
parent 0de1e00f0d
commit b18f20fc00
3 changed files with 142 additions and 24 deletions

View File

@@ -146,6 +146,18 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def set_skip_npz_path_check(skip: bool):
global SKIP_NPZ_PATH_CHECK
SKIP_NPZ_PATH_CHECK = skip
def npz_path_exists(path):
"""
Check if the (cached latents) path exists. This is necessary for NFS systems.
"""
if SKIP_NPZ_PATH_CHECK:
return True
return os.path.exists(path)
def split_train_val(
paths: List[str],
sizes: List[Optional[Tuple[int, int]]],
@@ -1392,7 +1404,7 @@ class BaseDataset(torch.utils.data.Dataset):
if not is_main_process: # store to info only
continue
if os.path.exists(te_out_npz):
if npz_path_exists(te_out_npz):
# TODO check varidity of cache here
continue
@@ -2168,7 +2180,7 @@ class FineTuningDataset(BaseDataset):
continue
# メタデータを読み込む
if os.path.exists(subset.metadata_file):
if npz_path_exists(subset.metadata_file):
logger.info(f"loading existing metadata: {subset.metadata_file}")
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
@@ -2182,12 +2194,13 @@ class FineTuningDataset(BaseDataset):
continue
tags_list = []
for image_key, img_md in metadata.items():
#for image_key, img_md in metadata.items():
for image_key, img_md in tqdm(metadata.items(), desc=f"load metadata: {subset.metadata_file}"):
# path情報を作る
abs_path = None
# まず画像を優先して探す
if os.path.exists(image_key):
if npz_path_exists(image_key):
abs_path = image_key
else:
# わりといい加減だがいい方法が思いつかん
@@ -2197,11 +2210,11 @@ class FineTuningDataset(BaseDataset):
# なければnpzを探す
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
if npz_path_exists(npz_path):
abs_path = npz_path
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -2335,10 +2348,10 @@ class FineTuningDataset(BaseDataset):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
if os.path.exists(npz_file_norm):
if npz_path_exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
if not npz_path_exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
@@ -2350,10 +2363,10 @@ class FineTuningDataset(BaseDataset):
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
if not os.path.exists(npz_file_norm):
if not npz_path_exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
elif not npz_path_exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
@@ -2662,7 +2675,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
if not npz_path_exists(npz_path):
return False
try:
@@ -4586,6 +4599,12 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時",
)
def add_skip_check_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--skip_npz_existence_check",
action="store_true",
help="skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用",
)
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
if not args.config_file:
@@ -5458,7 +5477,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
return text_encoder, vae, unet, load_stable_diffusion_format
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
def load_target_model_backup(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -5479,6 +5498,24 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args,
weight_dtype,
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
clean_memory_on_device(accelerator.device)
logger.info(f"Model loaded for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
def patch_accelerator_for_fp16_training(accelerator):
org_unscale_grads = accelerator.scaler._unscale_grads_
@@ -6313,13 +6350,15 @@ def sample_images_common(
except Exception:
pass
image_paths = []
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
image_paths = [sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
)]
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
@@ -6330,9 +6369,32 @@ def sample_images_common(
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
image_paths += [sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)]
accelerator.wait_for_everyone()
# if not main process, return
if accelerator.is_main_process:
try:
import wandb
logger.info(image_paths)
wandb_logger = accelerator.get_tracker("wandb")
# parse base filename without ext from first image path
for image_path_saved in get_all_paths_like_imagepaths_by_time(image_paths[0]):
# 0327_bs768_lion_highres_focus_fixxl4_000020_13_20240329061413_42
# get 13
file_basename = os.path.basename(image_path_saved).split(".")[0]
sample_idx = int(file_basename.split("_")[-3])
logger.info(f"sample_idx: {sample_idx} -> {image_path_saved}")
wandb_logger.log(
{f"sample_{sample_idx}" : wandb.Image(Image.open(image_path_saved))},
commit=False,
step=steps,
)
except Exception as e:
logger.warning(e)
pass
# clear pipeline and cache to reduce vram usage
del pipeline
@@ -6344,6 +6406,22 @@ def sample_images_common(
clean_memory_on_device(accelerator.device)
def get_all_paths_like_imagepaths_by_time(image_path):
file_basename = os.path.basename(image_path).split(".")[0]
timestamp_str = file_basename.split("_")[-2]
original_timestamp = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M%S")
front_fixed_part = "_".join(file_basename.split("_")[:-3])
for root, dirs, files in os.walk(os.path.dirname(image_path)):
for file in files:
front_fixed_part = "_".join(file_basename.split("_")[:-3])
if front_fixed_part in file:
timestamp_str = file.split("_")[-2]
timestamp = datetime.datetime.strptime(timestamp_str, "%Y%m%d%H%M%S")
# allow 60-second difference
if abs((timestamp - original_timestamp).total_seconds()) < 60:
yield os.path.join(root, file)
def sample_image_inference(
accelerator: Accelerator,
@@ -6431,14 +6509,15 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
return os.path.join(save_dir, img_filename)
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
#if "wandb" in [tracker.name for tracker in accelerator.trackers]:
# wandb_tracker = accelerator.get_tracker("wandb")
#
# import wandb
#
# # not to commit images to avoid inconsistency between training and logging steps
# wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):