diff --git a/library/train_util.py b/library/train_util.py index 2d93b126..55b5101b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -481,8 +481,7 @@ class BaseDataset(torch.utils.data.Dataset): else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: tokens = [t.strip() for t in caption.strip().split(",")] - print(subset.token_warmup_min, subset.token_warmup_step) - if subset.token_warmup_step < 1: + if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: tokens_len = ( @@ -1342,50 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset): def debug_dataset(train_dataset, show_input_ids=False): print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します") + print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") epoch = 1 - steps = 1 - train_dataset.set_current_epoch(epoch) - train_dataset.set_current_step(steps) + while True: + print(f"epoch: {epoch}") - k = 0 - indices = list(range(len(train_dataset))) - random.shuffle(indices) - for i, idx in enumerate(indices): - example = train_dataset[idx] - if example["latents"] is not None: - print(f"sample has latents from npz file: {example['latents'].size()}") - for j, (ik, cap, lw, iid) in enumerate( - zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) - ): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') - if show_input_ids: - print(f"input ids: {iid}") - if example["images"] is not None: - im = example["images"][j] - print(f"image size: {im.size()}") - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - if os.name == "nt": # only windows - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == ord("e"): - epoch += 1 - steps = len(train_dataset) * (epoch - 1) - train_dataset.set_current_epoch(epoch) - print(f"epoch: {epoch}") - - steps += 1 - train_dataset.set_current_step(steps) + steps = (epoch - 1) * len(train_dataset) + 1 + indices = list(range(len(train_dataset))) + random.shuffle(indices) - if k == 27 or (example["images"] is None and i >= 8): + k = 0 + for i, idx in enumerate(indices): + train_dataset.set_current_epoch(epoch) + train_dataset.set_current_step(steps) + print(f"steps: {steps} ({i + 1}/{len(train_dataset)})") + + example = train_dataset[idx] + if example["latents"] is not None: + print(f"sample has latents from npz file: {example['latents'].size()}") + for j, (ik, cap, lw, iid) in enumerate( + zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"]) + ): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') + if show_input_ids: + print(f"input ids: {iid}") + if example["images"] is not None: + im = example["images"][j] + print(f"image size: {im.size()}") + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + if os.name == "nt": # only windows + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27 or k == ord("s") or k == ord("e"): + break + steps += 1 + + if k == ord("e"): + break + if k == 27 or (example["images"] is None and i >= 8): + k = 27 + break + if k == 27: break + epoch += 1 + def glob_images(directory, base="*"): img_paths = [] @@ -1394,8 +1398,8 @@ def glob_images(directory, base="*"): img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) else: img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) - # img_paths = list(set(img_paths)) # 重複を排除 - # img_paths.sort() + img_paths = list(set(img_paths)) # 重複を排除 + img_paths.sort() return img_paths @@ -1407,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive): else: for ext in IMAGE_EXTENSIONS: image_paths += list(dir_path.glob("*" + ext)) - # image_paths = list(set(image_paths)) # 重複を排除 - # image_paths.sort() + image_paths = list(set(image_paths)) # 重複を排除 + image_paths.sort() return image_paths