mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into main
This commit is contained in:
@@ -45,6 +45,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
||||
|
||||
|
||||
class ImageInfo():
|
||||
@@ -87,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.enable_bucket = False
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.tag_frequency = {}
|
||||
self.bucket_info = None
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
@@ -115,6 +117,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
self.tag_frequency[dir_name] = frequency_for_dir
|
||||
for caption in captions:
|
||||
for tag in caption.split(","):
|
||||
if tag and not tag.isspace():
|
||||
tag = tag.lower()
|
||||
frequency = frequency_for_dir.get(tag, 0)
|
||||
frequency_for_dir[tag] = frequency + 1
|
||||
|
||||
def disable_token_padding(self):
|
||||
self.token_padding_disabled = True
|
||||
|
||||
@@ -247,7 +259,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices: list(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.buckets):
|
||||
@@ -545,6 +556,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
cap_for_img = read_caption(img_path)
|
||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
||||
|
||||
return n_repeats, img_paths, captions
|
||||
|
||||
print("prepare train images.")
|
||||
@@ -553,10 +566,13 @@ class DreamBoothDataset(BaseDataset):
|
||||
for dir in train_dirs:
|
||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
||||
num_train_images += n_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||
self.register_image(info)
|
||||
|
||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
@@ -570,9 +586,11 @@ class DreamBoothDataset(BaseDataset):
|
||||
for dir in reg_dirs:
|
||||
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
||||
num_reg_images += n_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||
reg_infos.append(info)
|
||||
|
||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
|
||||
print(f"{num_reg_images} reg images.")
|
||||
@@ -617,6 +635,7 @@ class FineTuningDataset(BaseDataset):
|
||||
self.train_data_dir = train_data_dir
|
||||
self.batch_size = batch_size
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
if os.path.exists(image_key):
|
||||
@@ -633,6 +652,7 @@ class FineTuningDataset(BaseDataset):
|
||||
caption = tags
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
tags_list.append(tags)
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||
|
||||
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
||||
@@ -646,7 +666,8 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
# check existence of all npz files
|
||||
if not self.color_aug:
|
||||
@@ -667,6 +688,8 @@ class FineTuningDataset(BaseDataset):
|
||||
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
|
||||
elif not npz_all:
|
||||
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
||||
if self.flip_aug:
|
||||
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
|
||||
@@ -756,15 +779,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
break
|
||||
|
||||
|
||||
def glob_images(dir, base):
|
||||
def glob_images(directory, base="*"):
|
||||
img_paths = []
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
if base == '*':
|
||||
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
|
||||
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||
else:
|
||||
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
|
||||
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||
# img_paths = list(set(img_paths)) # 重複を排除
|
||||
# img_paths.sort()
|
||||
return img_paths
|
||||
|
||||
|
||||
def glob_images_pathlib(dir_path, recursive):
|
||||
image_paths = []
|
||||
if recursive:
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
image_paths += list(dir_path.rglob('*' + ext))
|
||||
else:
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
image_paths += list(dir_path.glob('*' + ext))
|
||||
# image_paths = list(set(image_paths)) # 重複を排除
|
||||
# image_paths.sort()
|
||||
return image_paths
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -1497,5 +1535,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
|
||||
# endregion
|
||||
|
||||
# region 前処理用
|
||||
|
||||
|
||||
class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, image_paths):
|
||||
self.images = image_paths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.images[idx]
|
||||
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor_pil = transforms.functional.pil_to_tensor(image)
|
||||
except Exception as e:
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor_pil, img_path)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user