diff --git a/library/utils.py b/library/utils.py index 2171c719..8a0c782c 100644 --- a/library/utils.py +++ b/library/utils.py @@ -305,25 +305,19 @@ class MemoryEfficientSafeOpen: raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)") def pil_resize(image, size, interpolation=Image.LANCZOS): - # Check if the image has an alpha channel has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: - # Convert BGRA to RGBA pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: - # Convert BGR to RGB pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Resize the image resized_pil = pil_image.resize(size, interpolation) # Convert back to cv2 format if has_alpha: - # Convert RGBA to BGRA resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: - # Convert RGB to BGR resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2