mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Compare commits
7 Commits
dev
...
4b07c37f76
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b07c37f76 | ||
|
|
c19c90d66e | ||
|
|
b9ddf0ce37 | ||
|
|
1d4c578ac8 | ||
|
|
8fb676cb69 | ||
|
|
f73221038e | ||
|
|
4446d4b58d |
@@ -64,7 +64,12 @@ from library import custom_train_functions, sd3_utils
|
|||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import sys
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
try:
|
||||||
|
from PIL import ImageCms
|
||||||
|
except:
|
||||||
|
print( "ImageCms not available. Images will not be converted to sRGB. Colours may be handled incorrectly." )
|
||||||
import imagesize
|
import imagesize
|
||||||
import cv2
|
import cv2
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@@ -3004,10 +3009,36 @@ def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
|
|||||||
def load_image(image_path, alpha=False):
|
def load_image(image_path, alpha=False):
|
||||||
try:
|
try:
|
||||||
with Image.open(image_path) as image:
|
with Image.open(image_path) as image:
|
||||||
|
if getattr(image, "is_animated", False):
|
||||||
|
logger.warning( f"{image_path} is animated" )
|
||||||
|
|
||||||
|
# Convert image to sRGB
|
||||||
|
if "PIL.ImageCms" in sys.modules:
|
||||||
|
icc = image.info.get('icc_profile', '')
|
||||||
|
if icc:
|
||||||
|
try:
|
||||||
|
src_profile = ImageCms.ImageCmsProfile( BytesIO(icc) )
|
||||||
|
srgb_profile = ImageCms.createProfile("sRGB")
|
||||||
|
ImageCms.profileToProfile(image, src_profile, srgb_profile, inPlace=True)
|
||||||
|
image.info["icc_profile"] = ImageCms.ImageCmsProfile(srgb_profile).tobytes()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning( f"Could not convert {image_path} to sRGB: {src_profile.profile.model} {src_profile.profile.profile_description}\n{e}" )
|
||||||
|
|
||||||
if alpha:
|
if alpha:
|
||||||
if not image.mode == "RGBA":
|
if not image.mode == "RGBA":
|
||||||
image = image.convert("RGBA")
|
image = image.convert("RGBA")
|
||||||
else:
|
else:
|
||||||
|
if image.mode == "P":
|
||||||
|
# Palette images with alpha are easier to handle as RGBA.
|
||||||
|
image = image.convert('RGBA')
|
||||||
|
|
||||||
|
if "A" in image.getbands():
|
||||||
|
# Replace transparency with white background.
|
||||||
|
alpha_layer = image.convert('RGBA').split()[-1]
|
||||||
|
bg = Image.new("RGBA", image.size, (255, 255, 255, 255) )
|
||||||
|
bg.paste( image, mask=alpha_layer )
|
||||||
|
image = bg.convert('RGB')
|
||||||
|
|
||||||
if not image.mode == "RGB":
|
if not image.mode == "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
img = np.array(image, np.uint8)
|
img = np.array(image, np.uint8)
|
||||||
|
|||||||
Reference in New Issue
Block a user