Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Jakaline-dev
2023-03-30 01:35:38 +09:00
11 changed files with 3050 additions and 2526 deletions

View File

@@ -276,6 +276,8 @@ class BaseSubset:
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
@@ -289,6 +291,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.img_count = 0
@@ -309,6 +314,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -324,6 +331,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.is_reg = is_reg
@@ -351,6 +360,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -366,6 +377,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.metadata_file = metadata_file
@@ -406,6 +419,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation
self.aug_helper = AugHelper()
@@ -421,9 +438,19 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {}
def set_seed(self, seed):
self.seed = seed
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
self.current_epoch = epoch
self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
@@ -458,7 +485,16 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -470,10 +506,10 @@ class BaseDataset(torch.utils.data.Dataset):
return l
fixed_tokens = []
flex_tokens = [t.strip() for t in caption.strip().split(",")]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = flex_tokens[subset.keep_tokens :]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
@@ -643,6 +679,9 @@ class BaseDataset(torch.utils.data.Dataset):
self._length = len(self.buckets_indices)
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
@@ -1062,7 +1101,7 @@ class DreamBoothDataset(BaseDataset):
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
@@ -1123,6 +1162,8 @@ class FineTuningDataset(BaseDataset):
# path情報を作る
if os.path.exists(image_key):
abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
@@ -1308,6 +1349,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
@@ -1315,37 +1364,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
train_dataset.set_current_epoch(1)
k = 0
indices = list(range(len(train_dataset)))
random.shuffle(indices)
for i, idx in enumerate(indices):
example = train_dataset[idx]
if example["latents"] is not None:
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids:
print(f"input ids: {iid}")
if example["images"] is not None:
im = example["images"][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == 27 or (example["images"] is None and i >= 8):
epoch = 1
while True:
print(f"epoch: {epoch}")
steps = (epoch - 1) * len(train_dataset) + 1
indices = list(range(len(train_dataset)))
random.shuffle(indices)
k = 0
for i, idx in enumerate(indices):
train_dataset.set_current_epoch(epoch)
train_dataset.set_current_step(steps)
print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
example = train_dataset[idx]
if example["latents"] is not None:
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids:
print(f"input ids: {iid}")
if example["images"] is not None:
im = example["images"][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
if k == ord("e"):
break
if k == 27 or (example["images"] is None and i >= 8):
k = 27
break
if k == 27:
break
epoch += 1
def glob_images(directory, base="*"):
img_paths = []
@@ -1354,8 +1421,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort()
img_paths = list(set(img_paths)) # 重複を排除
img_paths.sort()
return img_paths
@@ -1367,8 +1434,8 @@ def glob_images_pathlib(dir_path, recursive):
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob("*" + ext))
# image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort()
image_paths = list(set(image_paths)) # 重複を排除
image_paths.sort()
return image_paths
@@ -2061,6 +2128,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
@@ -2995,3 +3076,24 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
# collate_fn用 epoch,stepはmultiprocessing.Value
class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]