update README, format code

This commit is contained in:
Kohya S
2024-09-07 10:45:18 +09:00
parent 16bb5699ac
commit 0005867ba5
3 changed files with 10 additions and 3 deletions

View File

@@ -2094,7 +2094,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
return train_dataset_group
def load_image(image_path, alpha=False):
def load_image(image_path, alpha=False):
try:
with Image.open(image_path) as image:
if alpha:

View File

@@ -11,6 +11,7 @@ import cv2
from PIL import Image
import numpy as np
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation=Image.LANCZOS):
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# use Pillow resize
@@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
# TODO make inf_utils.py