mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge pull request #129 from p1atdev/main
Add support for .jpeg images in glob
This commit is contained in:
@@ -31,7 +31,7 @@ def main(args):
|
||||
os.chdir('finetune')
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
@@ -30,7 +30,8 @@ def main(args):
|
||||
for image_path in tqdm(image_paths):
|
||||
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
with open(caption_path, "rt", encoding='utf-8') as f:
|
||||
caption = f.readlines()[0].strip()
|
||||
lines = f.readlines()
|
||||
caption = lines[0].strip() if len(lines) > 0 else ""
|
||||
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
if image_key not in metadata:
|
||||
|
||||
@@ -9,15 +9,16 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
image_paths = None
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
if args.recursive:
|
||||
image_paths = list(train_data_dir_path.rglob('*.jpg')) + \
|
||||
list(train_data_dir_path.rglob('*.jpeg')) + \
|
||||
list(train_data_dir_path.rglob('*.png')) + \
|
||||
list(train_data_dir_path.rglob('*.webp'))
|
||||
else:
|
||||
image_paths = list(train_data_dir_path.glob('*.jpg')) + \
|
||||
list(train_data_dir_path.glob('*.jpeg')) + \
|
||||
list(train_data_dir_path.glob('*.png')) + \
|
||||
list(train_data_dir_path.glob('*.webp'))
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_latents(vae, images, weight_dtype):
|
||||
|
||||
|
||||
def main(args):
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ def main(args):
|
||||
|
||||
# 画像を読み込む
|
||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.webp")) + \
|
||||
glob.glob(os.path.join(args.train_data_dir, "*.bmp"))
|
||||
|
||||
Reference in New Issue
Block a user