diff --git a/library/train_util.py b/library/train_util.py index bd2ff6ef..18d3cf6c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1082,6 +1082,10 @@ class BaseDataset(torch.utils.data.Dataset): info.image = info.image.result() # future to image caching_strategy.cache_batch_latents(model, batch, cond.flip_aug, cond.alpha_mask, cond.random_crop) + # remove image from memory + for info in batch: + info.image = None + # define ThreadPoolExecutor to load images in parallel max_workers = min(os.cpu_count(), len(image_infos)) max_workers = max(1, max_workers // num_processes) # consider multi-gpu @@ -1397,7 +1401,17 @@ class BaseDataset(torch.utils.data.Dataset): ) def get_image_size(self, image_path): - return imagesize.get(image_path) + # return imagesize.get(image_path) + image_size = imagesize.get(image_path) + if image_size[0] <= 0: + # imagesize doesn't work for some images, so use cv2 + img = cv2.imread(image_path) + if img is not None: + image_size = (img.shape[1], img.shape[0]) + else: + logger.warning(f"failed to get image size: {image_path}") + image_size = (0, 0) + return image_size def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask)