mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix images are used twice, update debug dataset
This commit is contained in:
@@ -481,8 +481,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||||
print(subset.token_warmup_min, subset.token_warmup_step)
|
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||||
if subset.token_warmup_step < 1:
|
|
||||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||||
tokens_len = (
|
tokens_len = (
|
||||||
@@ -1342,17 +1341,22 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
|
|
||||||
def debug_dataset(train_dataset, show_input_ids=False):
|
def debug_dataset(train_dataset, show_input_ids=False):
|
||||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")
|
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
|
||||||
|
|
||||||
epoch = 1
|
epoch = 1
|
||||||
steps = 1
|
while True:
|
||||||
train_dataset.set_current_epoch(epoch)
|
print(f"epoch: {epoch}")
|
||||||
train_dataset.set_current_step(steps)
|
|
||||||
|
|
||||||
k = 0
|
steps = (epoch - 1) * len(train_dataset) + 1
|
||||||
indices = list(range(len(train_dataset)))
|
indices = list(range(len(train_dataset)))
|
||||||
random.shuffle(indices)
|
random.shuffle(indices)
|
||||||
|
|
||||||
|
k = 0
|
||||||
for i, idx in enumerate(indices):
|
for i, idx in enumerate(indices):
|
||||||
|
train_dataset.set_current_epoch(epoch)
|
||||||
|
train_dataset.set_current_step(steps)
|
||||||
|
print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
|
||||||
|
|
||||||
example = train_dataset[idx]
|
example = train_dataset[idx]
|
||||||
if example["latents"] is not None:
|
if example["latents"] is not None:
|
||||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||||
@@ -1372,19 +1376,19 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
k = cv2.waitKey()
|
k = cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
if k == 27 or k == ord("s") or k == ord("e"):
|
||||||
|
break
|
||||||
|
steps += 1
|
||||||
|
|
||||||
|
if k == ord("e"):
|
||||||
|
break
|
||||||
|
if k == 27 or (example["images"] is None and i >= 8):
|
||||||
|
k = 27
|
||||||
|
break
|
||||||
if k == 27:
|
if k == 27:
|
||||||
break
|
break
|
||||||
if k == ord("e"):
|
|
||||||
epoch += 1
|
epoch += 1
|
||||||
steps = len(train_dataset) * (epoch - 1)
|
|
||||||
train_dataset.set_current_epoch(epoch)
|
|
||||||
print(f"epoch: {epoch}")
|
|
||||||
|
|
||||||
steps += 1
|
|
||||||
train_dataset.set_current_step(steps)
|
|
||||||
|
|
||||||
if k == 27 or (example["images"] is None and i >= 8):
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def glob_images(directory, base="*"):
|
def glob_images(directory, base="*"):
|
||||||
@@ -1394,8 +1398,8 @@ def glob_images(directory, base="*"):
|
|||||||
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||||
else:
|
else:
|
||||||
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||||
# img_paths = list(set(img_paths)) # 重複を排除
|
img_paths = list(set(img_paths)) # 重複を排除
|
||||||
# img_paths.sort()
|
img_paths.sort()
|
||||||
return img_paths
|
return img_paths
|
||||||
|
|
||||||
|
|
||||||
@@ -1407,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive):
|
|||||||
else:
|
else:
|
||||||
for ext in IMAGE_EXTENSIONS:
|
for ext in IMAGE_EXTENSIONS:
|
||||||
image_paths += list(dir_path.glob("*" + ext))
|
image_paths += list(dir_path.glob("*" + ext))
|
||||||
# image_paths = list(set(image_paths)) # 重複を排除
|
image_paths = list(set(image_paths)) # 重複を排除
|
||||||
# image_paths.sort()
|
image_paths.sort()
|
||||||
return image_paths
|
return image_paths
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user