Update downsampling for larger image in no_upscale

This commit is contained in:
Kohya S
2023-02-06 20:52:24 +09:00
parent ea2dfd09ef
commit 2aa27b7a4b

View File

@@ -123,6 +123,10 @@ class BucketManager():
self.buckets.append([]) self.buckets.append([])
# print(reso, bucket_id, len(self.buckets)) # print(reso, bucket_id, len(self.buckets))
def round_to_steps(self, x):
x = int(x + .5)
return x - x % self.reso_steps
def select_bucket(self, image_width, image_height): def select_bucket(self, image_width, image_height):
aspect_ratio = image_width / image_height aspect_ratio = image_width / image_height
if not self.no_upscale: if not self.no_upscale:
@@ -150,7 +154,24 @@ class BucketManager():
resized_height = self.max_area / resized_width resized_height = self.max_area / resized_width
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal" assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
resized_size = (int(resized_width + .5), int(resized_height + .5)) # リサイズ後の短辺または長辺をreso_steps単位にするaspect ratioの差が少ないほうを選ぶ
# 元のbucketingと同じロジック
b_width_rounded = self.round_to_steps(resized_width)
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
ar_width_rounded = b_width_rounded / b_height_in_wr
b_height_rounded = self.round_to_steps(resized_height)
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
ar_height_rounded = b_width_in_hr / b_height_rounded
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
else:
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
# print(resized_size)
else: else:
resized_size = (image_width, image_height) # リサイズは不要 resized_size = (image_width, image_height) # リサイズは不要
@@ -889,13 +910,14 @@ def debug_dataset(train_dataset, show_input_ids=False):
k = 0 k = 0
for i, example in enumerate(train_dataset): for i, example in enumerate(train_dataset):
if example['latents'] is not None: if example['latents'] is not None:
print("sample has latents from npz file") print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])): for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids: if show_input_ids:
print(f"input ids: {iid}") print(f"input ids: {iid}")
if example['images'] is not None: if example['images'] is not None:
im = example['images'][j] im = example['images'][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV) im = im[:, :, ::-1] # RGB -> BGR (OpenCV)