add GIT captioning, refactoring, DataLoader

This commit is contained in:
Kohya S
2023-02-03 08:45:33 +09:00
parent 8c3a52ecc9
commit 57d8483eaf
9 changed files with 479 additions and 144 deletions

View File

@@ -44,7 +44,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
class ImageInfo():
@@ -141,7 +141,7 @@ class BaseDataset(torch.utils.data.Dataset):
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
caption = str_to
else:
caption = caption.replace(str_from, str_to)
@@ -247,7 +247,6 @@ class BaseDataset(torch.utils.data.Dataset):
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = []
@@ -766,15 +765,30 @@ def debug_dataset(train_dataset, show_input_ids=False):
break
def glob_images(dir, base):
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == '*':
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
img_paths = list(set(img_paths)) # 重複を排除
img_paths.sort()
return img_paths
def glob_images_pathlib(dir_path, recursive):
image_paths = []
if recursive:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.rglob('*' + ext))
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob('*' + ext))
image_paths = list(set(image_paths)) # 重複を排除
image_paths.sort()
return image_paths
# endregion
@@ -1505,5 +1519,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion
# region 前処理用
class ImageLoadingDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor_pil = transforms.functional.pil_to_tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor_pil, img_path)
# endregion