mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update downsampling for larger images in no_upscale mode
This commit is contained in:
@@ -123,6 +123,10 @@ class BucketManager():
|
||||
self.buckets.append([])
|
||||
# print(reso, bucket_id, len(self.buckets))
|
||||
|
||||
def round_to_steps(self, x):
    """Snap ``x`` down onto the ``self.reso_steps`` grid.

    ``x`` is first converted to an integer by adding 0.5 and truncating
    with ``int()``; the remainder modulo ``self.reso_steps`` is then
    subtracted, so the result is a multiple of ``self.reso_steps`` that
    does not exceed that integer. Values smaller than one step yield 0.
    """
    nearest = int(x + .5)
    overshoot = nearest % self.reso_steps
    return nearest - overshoot
|
||||
|
||||
def select_bucket(self, image_width, image_height):
|
||||
aspect_ratio = image_width / image_height
|
||||
if not self.no_upscale:
|
||||
@@ -150,7 +154,24 @@ class BucketManager():
|
||||
resized_height = self.max_area / resized_width
|
||||
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
||||
|
||||
resized_size = (int(resized_width + .5), int(resized_height + .5))
|
||||
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
||||
# 元のbucketingと同じロジック
|
||||
b_width_rounded = self.round_to_steps(resized_width)
|
||||
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
||||
ar_width_rounded = b_width_rounded / b_height_in_wr
|
||||
|
||||
b_height_rounded = self.round_to_steps(resized_height)
|
||||
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
||||
ar_height_rounded = b_width_in_hr / b_height_rounded
|
||||
|
||||
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
||||
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
||||
|
||||
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
||||
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
||||
else:
|
||||
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
||||
# print(resized_size)
|
||||
else:
|
||||
resized_size = (image_width, image_height) # リサイズは不要
|
||||
|
||||
@@ -889,13 +910,14 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
k = 0
|
||||
for i, example in enumerate(train_dataset):
|
||||
if example['latents'] is not None:
|
||||
print("sample has latents from npz file")
|
||||
print(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if example['images'] is not None:
|
||||
im = example['images'][j]
|
||||
print(f"image size: {im.size()}")
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
||||
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
||||
|
||||
Reference in New Issue
Block a user