mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix images are used twice, update debug dataset
This commit is contained in:
@@ -481,8 +481,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
print(subset.token_warmup_min, subset.token_warmup_step)
|
||||
if subset.token_warmup_step < 1:
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
@@ -1342,50 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します")
|
||||
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
|
||||
|
||||
epoch = 1
|
||||
steps = 1
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
train_dataset.set_current_step(steps)
|
||||
while True:
|
||||
print(f"epoch: {epoch}")
|
||||
|
||||
k = 0
|
||||
indices = list(range(len(train_dataset)))
|
||||
random.shuffle(indices)
|
||||
for i, idx in enumerate(indices):
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid) in enumerate(
|
||||
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
|
||||
):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
print(f"image size: {im.size()}")
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27:
|
||||
break
|
||||
if k == ord("e"):
|
||||
epoch += 1
|
||||
steps = len(train_dataset) * (epoch - 1)
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
print(f"epoch: {epoch}")
|
||||
|
||||
steps += 1
|
||||
train_dataset.set_current_step(steps)
|
||||
steps = (epoch - 1) * len(train_dataset) + 1
|
||||
indices = list(range(len(train_dataset)))
|
||||
random.shuffle(indices)
|
||||
|
||||
if k == 27 or (example["images"] is None and i >= 8):
|
||||
k = 0
|
||||
for i, idx in enumerate(indices):
|
||||
train_dataset.set_current_epoch(epoch)
|
||||
train_dataset.set_current_step(steps)
|
||||
print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
|
||||
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid) in enumerate(
|
||||
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
|
||||
):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
print(f"image size: {im.size()}")
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||
if os.name == "nt": # only windows
|
||||
cv2.imshow("img", im)
|
||||
k = cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
if k == 27 or k == ord("s") or k == ord("e"):
|
||||
break
|
||||
steps += 1
|
||||
|
||||
if k == ord("e"):
|
||||
break
|
||||
if k == 27 or (example["images"] is None and i >= 8):
|
||||
k = 27
|
||||
break
|
||||
if k == 27:
|
||||
break
|
||||
|
||||
epoch += 1
|
||||
|
||||
|
||||
def glob_images(directory, base="*"):
|
||||
img_paths = []
|
||||
@@ -1394,8 +1398,8 @@ def glob_images(directory, base="*"):
|
||||
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
||||
else:
|
||||
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
||||
# img_paths = list(set(img_paths)) # 重複を排除
|
||||
# img_paths.sort()
|
||||
img_paths = list(set(img_paths)) # 重複を排除
|
||||
img_paths.sort()
|
||||
return img_paths
|
||||
|
||||
|
||||
@@ -1407,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
else:
|
||||
for ext in IMAGE_EXTENSIONS:
|
||||
image_paths += list(dir_path.glob("*" + ext))
|
||||
# image_paths = list(set(image_paths)) # 重複を排除
|
||||
# image_paths.sort()
|
||||
image_paths = list(set(image_paths)) # 重複を排除
|
||||
image_paths.sort()
|
||||
return image_paths
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user