This commit is contained in:
unknown
2023-01-29 22:30:45 +09:00
parent 67e698af67
commit c20745b6e8
5 changed files with 7 additions and 6 deletions

View File

@@ -31,7 +31,7 @@ def main(args):
os.chdir('finetune')
print(f"load images from {args.train_data_dir}")
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")
@@ -30,7 +30,8 @@ def main(args):
for image_path in tqdm(image_paths):
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
with open(caption_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
lines = f.readlines()
caption = lines[0].strip() if len(lines) > 0 else ""
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:

View File

@@ -10,7 +10,7 @@ from tqdm import tqdm
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")

View File

@@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype):
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")

View File

@@ -36,7 +36,7 @@ def main(args):
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
# 画像を読み込む
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
print(f"found {len(image_paths)} images.")