initial version of TI

This commit is contained in:
Kohya S
2023-01-12 20:47:08 +09:00
parent f981dfd38a
commit c1b14fcdd6
2 changed files with 513 additions and 3 deletions

View File

@@ -104,9 +104,14 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_data: dict[str, ImageInfo] = {}
self.replacements = {}
# Mark token padding as disabled for this dataset by setting the
# `token_padding_disabled` flag to True. Where the flag is read is not
# visible in this excerpt.
def disable_token_padding(self):
self.token_padding_disabled = True
# Register a caption replacement rule: store `str_to` under the key
# `str_from` in `self.replacements`, overwriting any existing rule for the
# same key. process_caption applies these rules: an empty-string key means
# "replace the whole caption" (and only then may `str_to` be a list, from
# which one entry is chosen at random); any other key is applied with
# caption.replace(str_from, str_to).
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
def process_caption(self, caption):
if self.shuffle_caption:
tokens = caption.strip().split(",")
@@ -119,6 +124,17 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ",".join(tokens).strip()
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
def get_input_ids(self, caption):
@@ -589,7 +605,7 @@ class FineTuningDataset(BaseDataset):
else:
# Fairly crude, but I can't think of a better way
abs_path = glob_images(train_data_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]
caption = img_md.get('caption')
@@ -689,15 +705,17 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
def debug_dataset(train_dataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
k = 0
for example in train_dataset:
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
if show_input_ids:
print(f"input ids: {iid}")
if example['images'] is not None:
im = example['images'][j]
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)