mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
skip npz check for multi node training
This commit is contained in:
@@ -24,8 +24,7 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
# DEFAULT_NOISE_OFFSET = 0.0357
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
def load_target_model_backup(args, accelerator, model_version: str, weight_dtype):
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
@@ -61,6 +60,38 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
    """Load the SDXL model components on this process, then sync with the other processes.

    There is no per-process turn-taking here: each process calls
    ``_load_target_model`` right away, and the only cross-process
    synchronization is the ``wait_for_everyone()`` call before returning.

    Args:
        args: parsed arguments; ``pretrained_model_name_or_path``, ``vae`` and
            ``lowram`` are read here.
        accelerator: supplies the device, the process indices shown in the log
            messages, and the final barrier.
        model_version: forwarded unchanged to ``_load_target_model``.
        weight_dtype: forwarded to ``match_mixed_precision`` and
            ``_load_target_model``.

    Returns:
        ``(load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet,
        logit_scale, ckpt_info)`` as returned by ``_load_target_model``.
    """
    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
    logger.info(f"loading model for process {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")

    # with --lowram the weights are loaded straight onto the accelerator device
    load_device = accelerator.device if args.lowram else "cpu"
    loaded = _load_target_model(
        args.pretrained_model_name_or_path,
        args.vae,
        model_version,
        weight_dtype,
        load_device,
        model_dtype,
    )
    load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = loaded

    # work on low-ram device
    if args.lowram:
        for module in (text_encoder1, text_encoder2, unet, vae):
            module.to(accelerator.device)

    clean_memory_on_device(accelerator.device)
    logger.info(f"model loaded for process {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")

    # barrier: do not continue until every process has finished loading
    accelerator.wait_for_everyone()
    logger.info(f"model loaded for all processes {accelerator.state.local_process_index} {accelerator.state.process_index} /{accelerator.state.num_processes}")

    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
def _load_target_model(
|
||||
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
|
||||
|
||||
@@ -146,6 +146,18 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def set_skip_npz_path_check(skip: bool):
    """Store ``skip`` in the module-level ``SKIP_NPZ_PATH_CHECK`` flag.

    ``npz_path_exists`` reads this flag; while it is truthy that function
    returns ``True`` without consulting the filesystem.
    """
    globals()["SKIP_NPZ_PATH_CHECK"] = skip
|
||||
|
||||
def npz_path_exists(path):
    """Return whether ``path`` (e.g. a cached-latents file) should be treated as existing.

    When the module-level ``SKIP_NPZ_PATH_CHECK`` flag is truthy the filesystem
    is not consulted and ``True`` is returned unconditionally; otherwise the
    result is ``os.path.exists(path)``. The switch exists for NFS systems.

    The flag is read with a ``False`` default, so calling this before
    ``set_skip_npz_path_check`` has ever run performs the real existence check
    instead of raising ``NameError`` on the not-yet-created global.
    """
    if globals().get("SKIP_NPZ_PATH_CHECK", False):
        return True
    return os.path.exists(path)
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
sizes: List[Optional[Tuple[int, int]]],
|
||||
@@ -1392,7 +1404,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
if os.path.exists(te_out_npz):
|
||||
if npz_path_exists(te_out_npz):
|
||||
# TODO check validity of cache here
|
||||
continue
|
||||
|
||||
@@ -2168,7 +2180,7 @@ class FineTuningDataset(BaseDataset):
|
||||
continue
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
if npz_path_exists(subset.metadata_file):
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
@@ -2182,12 +2194,13 @@ class FineTuningDataset(BaseDataset):
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
#for image_key, img_md in metadata.items():
|
||||
for image_key, img_md in tqdm(metadata.items(), desc=f"load metadata: {subset.metadata_file}"):
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
if npz_path_exists(image_key):
|
||||
abs_path = image_key
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
@@ -2197,11 +2210,11 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
if npz_path_exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
if npz_path_exists(npz_path):
|
||||
abs_path = npz_path
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
@@ -2335,10 +2348,10 @@ class FineTuningDataset(BaseDataset):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
if os.path.exists(npz_file_norm):
|
||||
if npz_path_exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
if not npz_path_exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
@@ -2350,10 +2363,10 @@ class FineTuningDataset(BaseDataset):
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
if not npz_path_exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
elif not npz_path_exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
@@ -2662,7 +2675,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
if not os.path.exists(npz_path):
|
||||
if not npz_path_exists(npz_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -4586,6 +4599,12 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)",
|
||||
)
|
||||
|
||||
def add_skip_check_arguments(parser: argparse.ArgumentParser):
    """Register the ``--skip_npz_existence_check`` boolean switch on ``parser``.

    The option is a ``store_true`` flag, so it defaults to ``False`` when absent.
    """
    help_text = "skip check for images and latents existence, useful if your storage has low random access speed / 画像とlatentの存在チェックをスキップする。ストレージのランダムアクセス速度が遅い場合に有用"
    parser.add_argument("--skip_npz_existence_check", action="store_true", help=help_text)
|
||||
|
||||
def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
if not args.config_file:
|
||||
@@ -5458,7 +5477,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
def load_target_model_backup(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
@@ -5479,6 +5498,24 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
accelerator.wait_for_everyone()
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
    """Load the text encoder, VAE and U-Net on this process, then sync with the other processes.

    Each process calls ``_load_target_model`` immediately; there is no
    per-process turn-taking, and the only cross-process synchronization is the
    ``wait_for_everyone()`` call before returning.

    Args:
        args: parsed arguments; ``lowram`` is read here and the whole namespace
            is forwarded to ``_load_target_model``.
        weight_dtype: forwarded to ``_load_target_model``.
        accelerator: supplies the device, the process indices shown in the log
            messages, and the final barrier.
        unet_use_linear_projection_in_v2: forwarded to ``_load_target_model``.

    Returns:
        ``(text_encoder, vae, unet, load_stable_diffusion_format)``.
    """
    logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")

    # with --lowram the weights are loaded straight onto the accelerator device
    load_device = accelerator.device if args.lowram else "cpu"
    loaded = _load_target_model(
        args,
        weight_dtype,
        load_device,
        unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
    )
    text_encoder, vae, unet, load_stable_diffusion_format = loaded

    # work on low-ram device
    if args.lowram:
        for module in (text_encoder, unet, vae):
            module.to(accelerator.device)

    clean_memory_on_device(accelerator.device)
    logger.info(f"Model loaded for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}, {accelerator.process_index}")

    # barrier: do not continue until every process has finished loading
    accelerator.wait_for_everyone()
    return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
@@ -6313,13 +6350,15 @@ def sample_images_common(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
image_paths = []
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
image_paths = [sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
)]
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
@@ -6330,9 +6369,32 @@ def sample_images_common(
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
image_paths += [sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)]
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
# only the main process uploads the saved sample images to wandb
|
||||
if accelerator.is_main_process:
|
||||
try:
|
||||
import wandb
|
||||
logger.info(image_paths)
|
||||
wandb_logger = accelerator.get_tracker("wandb")
|
||||
# parse base filename without ext from first image path
|
||||
for image_path_saved in get_all_paths_like_imagepaths_by_time(image_paths[0]):
|
||||
# 0327_bs768_lion_highres_focus_fixxl4_000020_13_20240329061413_42
|
||||
# get 13
|
||||
file_basename = os.path.basename(image_path_saved).split(".")[0]
|
||||
sample_idx = int(file_basename.split("_")[-3])
|
||||
logger.info(f"sample_idx: {sample_idx} -> {image_path_saved}")
|
||||
wandb_logger.log(
|
||||
{f"sample_{sample_idx}" : wandb.Image(Image.open(image_path_saved))},
|
||||
commit=False,
|
||||
step=steps,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
pass
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
@@ -6344,6 +6406,22 @@ def sample_images_common(
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
def get_all_paths_like_imagepaths_by_time(image_path):
    """Yield paths of files saved next to ``image_path`` in the same sampling run.

    The file name (without extension) is split on ``_``; the second-to-last
    field is parsed as a ``YYYYmmddHHMMSS`` timestamp and everything before the
    last three fields is taken as the shared prefix, e.g.
    ``name_000020_13_20240329061413_42`` -> prefix ``name_000020``.

    The directory containing ``image_path`` is walked recursively. A file is
    yielded when its name contains the shared prefix and its own timestamp is
    less than 60 seconds away from the reference one. Files that contain the
    prefix but have no parseable timestamp field are skipped rather than
    aborting the scan.

    Args:
        image_path: path of one saved sample image used as the reference.

    Yields:
        ``os.path.join(root, file)`` for every matching file.

    Raises:
        IndexError, ValueError: if ``image_path`` itself does not carry a
            parseable timestamp field.
    """
    timestamp_format = "%Y%m%d%H%M%S"

    # splitext (rather than split(".")[0]) keeps names that contain dots intact
    reference_fields = os.path.splitext(os.path.basename(image_path))[0].split("_")
    original_timestamp = datetime.datetime.strptime(reference_fields[-2], timestamp_format)
    front_fixed_part = "_".join(reference_fields[:-3])

    # a bare file name has an empty dirname; walk the current directory then
    search_root = os.path.dirname(image_path) or os.curdir

    for root, _dirs, files in os.walk(search_root):
        for file in files:
            if front_fixed_part not in file:
                continue
            candidate_fields = os.path.splitext(file)[0].split("_")
            try:
                timestamp = datetime.datetime.strptime(candidate_fields[-2], timestamp_format)
            except (IndexError, ValueError):
                # unrelated file that merely shares the prefix
                continue
            # allow 60-second difference
            if abs((timestamp - original_timestamp).total_seconds()) < 60:
                yield os.path.join(root, file)
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
@@ -6431,14 +6509,15 @@ def sample_image_inference(
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
return os.path.join(save_dir, img_filename)
|
||||
# send images to wandb if enabled
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
|
||||
import wandb
|
||||
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
#if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
# wandb_tracker = accelerator.get_tracker("wandb")
|
||||
#
|
||||
# import wandb
|
||||
#
|
||||
# # not to commit images to avoid inconsistency between training and logging steps
|
||||
# wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
|
||||
Reference in New Issue
Block a user