show file name if error in load_image ref #1385

This commit is contained in:
Kohya S
2024-06-25 20:03:09 +09:00
parent 9dd1ee458c
commit 0b3e4f7ab6

View File

@@ -2434,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
return train_dataset_group return train_dataset_group
def load_image(image_path, alpha=False): def load_image(image_path, alpha=False):
image = Image.open(image_path) try:
if alpha: with Image.open(image_path) as image:
if not image.mode == "RGBA": if alpha:
image = image.convert("RGBA") if not image.mode == "RGBA":
else: image = image.convert("RGBA")
if not image.mode == "RGB": else:
image = image.convert("RGB") if not image.mode == "RGB":
img = np.array(image, np.uint8) image = image.convert("RGB")
return img img = np.array(image, np.uint8)
return img
except (IOError, OSError) as e:
logger.error(f"Error loading file: {image_path}")
raise e
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)