mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
initial version of TI
This commit is contained in:
@@ -104,9 +104,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_data: dict[str, ImageInfo] = {}
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def disable_token_padding(self):
    """Turn off token padding for this dataset.

    Only sets the ``token_padding_disabled`` flag on the instance; the code
    that consults the flag is not part of this excerpt.
    """
    self.token_padding_disabled = True
|
||||
|
||||
def add_replacement(self, str_from, str_to):
    """Register a caption replacement rule.

    The rule is stored in ``self.replacements`` keyed by ``str_from``; adding
    the same ``str_from`` again overwrites the earlier rule.

    As consumed by ``process_caption``: an empty ``str_from`` replaces the
    whole caption with ``str_to`` (or with a random element when ``str_to``
    is a list); any other ``str_from`` is substituted via ``str.replace``.
    """
    rules = self.replacements
    rules[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
@@ -119,6 +124,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
random.shuffle(tokens)
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ",".join(tokens).strip()
|
||||
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
else:
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
|
||||
return caption
|
||||
|
||||
def get_input_ids(self, caption):
|
||||
@@ -589,7 +605,7 @@ class FineTuningDataset(BaseDataset):
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(train_data_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||
abs_path = abs_path[0]
|
||||
|
||||
caption = img_md.get('caption')
|
||||
@@ -689,15 +705,17 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
def debug_dataset(train_dataset):
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
k = 0
|
||||
for example in train_dataset:
|
||||
if example['latents'] is not None:
|
||||
print("sample has latents from npz file")
|
||||
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
|
||||
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if example['images'] is not None:
|
||||
im = example['images'][j]
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
|
||||
Reference in New Issue
Block a user