From 2aa27b7a4b1389775e41581d6a25dbffad4225e0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 Feb 2023 20:52:24 +0900 Subject: [PATCH] Update downsampling for larger image in no_upscale --- library/train_util.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c4f697ab..6f809deb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -123,6 +123,10 @@ class BucketManager(): self.buckets.append([]) # print(reso, bucket_id, len(self.buckets)) + def round_to_steps(self, x): + x = int(x + .5) + return x - x % self.reso_steps + def select_bucket(self, image_width, image_height): aspect_ratio = image_width / image_height if not self.no_upscale: @@ -150,7 +154,24 @@ class BucketManager(): resized_height = self.max_area / resized_width assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" - resized_size = (int(resized_width + .5), int(resized_height + .5)) + # リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ + # 元のbucketingと同じロジック + b_width_rounded = self.round_to_steps(resized_width) + b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio) + ar_width_rounded = b_width_rounded / b_height_in_wr + + b_height_rounded = self.round_to_steps(resized_height) + b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio) + ar_height_rounded = b_width_in_hr / b_height_rounded + + # print(b_width_rounded, b_height_in_wr, ar_width_rounded) + # print(b_width_in_hr, b_height_rounded, ar_height_rounded) + + if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio): + resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5)) + else: + resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded) + # print(resized_size) else: resized_size = (image_width, image_height) # リサイズは不要 @@ -889,13 +910,14 @@ def debug_dataset(train_dataset, show_input_ids=False): k = 0 for i, example in enumerate(train_dataset): if example['latents'] is not None: - print("sample has latents from npz file") + print(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"') if show_input_ids: print(f"input ids: {iid}") if example['images'] is not None: im = example['images'][j] + print(f"image size: {im.size()}") im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = im[:, :, ::-1] # RGB -> BGR (OpenCV)