mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
feat: Florence-2 captioninig (WIP)
This commit is contained in:
232
finetune/caption_images_by_florence2.py
Normal file
232
finetune/caption_images_by_florence2.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# add caption to images by Florence-2
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
|
||||
from library import device_utils, train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import tagger_utils
|
||||
|
||||
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
|
||||
|
||||
|
||||
def main(args):
|
||||
assert args.load_archive == (
|
||||
args.metadata is not None
|
||||
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
|
||||
|
||||
device = args.device if args.device is not None else device_utils.get_preferred_device()
|
||||
if type(device) is str:
|
||||
device = torch.device(device)
|
||||
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
|
||||
logger.info(f"device: {device}, dtype: {torch_dtype}")
|
||||
|
||||
logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")
|
||||
|
||||
support_flash_attn = False
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
support_flash_attn = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if support_flash_attn:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
|
||||
).to(device)
|
||||
else:
|
||||
logger.info(
|
||||
"flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
|
||||
# Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
|
||||
from transformers.dynamic_module_utils import get_imports
|
||||
from unittest.mock import patch
|
||||
|
||||
def fixed_get_imports(filename) -> list[str]:
|
||||
if not str(filename).endswith("modeling_florence2.py"):
|
||||
return get_imports(filename)
|
||||
imports = get_imports(filename)
|
||||
imports.remove("flash_attn")
|
||||
return imports
|
||||
|
||||
# workaround for unnecessary flash_attn requirement
|
||||
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
|
||||
).to(device)
|
||||
|
||||
model.eval()
|
||||
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
||||
|
||||
# 画像を読み込む
|
||||
if not args.load_archive:
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
else:
|
||||
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
|
||||
os.path.join(args.train_data_dir, "*.tar")
|
||||
)
|
||||
image_paths = [Path(archive_file) for archive_file in archive_files]
|
||||
|
||||
# load metadata if needed
|
||||
if args.metadata is not None:
|
||||
metadata = tagger_utils.load_metadata(args.metadata)
|
||||
images_metadata = metadata["images"]
|
||||
else:
|
||||
images_metadata = metadata = None
|
||||
|
||||
# define preprocess_image function
|
||||
def preprocess_image(image: Image.Image):
|
||||
inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
|
||||
return inputs
|
||||
|
||||
# prepare DataLoader or something similar :)
|
||||
# Loader returns: list of (image_path, processed_image_or_something, image_size)
|
||||
if args.load_archive:
|
||||
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
|
||||
else:
|
||||
# we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
|
||||
loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)
|
||||
|
||||
def run_batch(
|
||||
list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
|
||||
images_metadata: Optional[dict[str, Any]],
|
||||
caption_index: Optional[int] = None,
|
||||
):
|
||||
input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
|
||||
pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])
|
||||
|
||||
if args.debug:
|
||||
logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
num_beams=args.num_beams,
|
||||
)
|
||||
if args.debug:
|
||||
logger.info(f"generate done: {generated_ids.shape}")
|
||||
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
|
||||
if args.debug:
|
||||
logger.info(f"decode done: {len(generated_texts)}")
|
||||
|
||||
for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
|
||||
parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
|
||||
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]
|
||||
|
||||
caption_text = caption_text.strip().replace("<pad>", "")
|
||||
original_caption_text = caption_text
|
||||
|
||||
if args.remove_mood:
|
||||
p = caption_text.find("The overall ")
|
||||
if p != -1:
|
||||
caption_text = caption_text[:p].strip()
|
||||
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
if images_metadata is None:
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(caption_text + "\n")
|
||||
else:
|
||||
image_md = images_metadata.get(image_path, None)
|
||||
if image_md is None:
|
||||
image_md = {"image_size": list(image_size)}
|
||||
images_metadata[image_path] = image_md
|
||||
if "caption" not in image_md:
|
||||
image_md["caption"] = []
|
||||
if caption_index is None:
|
||||
image_md["caption"].append(caption_text)
|
||||
else:
|
||||
while len(image_md["caption"]) <= caption_index:
|
||||
image_md["caption"].append("")
|
||||
image_md["caption"][caption_index] = caption_text
|
||||
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tCaption: {caption_text}")
|
||||
if args.remove_mood and original_caption_text != caption_text:
|
||||
logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")
|
||||
|
||||
for data_entry in tqdm(loader, smoothing=0.0):
|
||||
b_imgs = data_entry
|
||||
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
|
||||
run_batch(b_imgs, images_metadata, args.caption_index)
|
||||
|
||||
if args.metadata is not None:
|
||||
logger.info(f"saving metadata file: {args.metadata}")
|
||||
with open(args.metadata, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
|
||||
)
|
||||
parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
|
||||
parser.add_argument(
|
||||
"--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=3,
|
||||
help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default=None,
|
||||
help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
|
||||
" / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||
tagger_utils.add_archive_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,13 +1,10 @@
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import csv
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
import zipfile
|
||||
import tarfile
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -16,14 +13,17 @@ from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import pil_resize
|
||||
import tagger_utils
|
||||
|
||||
# from wd14 tagger
|
||||
IMAGE_SIZE = 448
|
||||
|
||||
@@ -79,83 +79,6 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
return (image, img_path, size)
|
||||
|
||||
|
||||
class ArchiveImageLoader:
|
||||
def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False):
|
||||
self.archive_paths = archive_paths
|
||||
self.batch_size = batch_size
|
||||
self.debug = debug
|
||||
self.current_archive = None
|
||||
self.archive_index = 0
|
||||
self.image_index = 0
|
||||
self.files = None
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
images = []
|
||||
while len(images) < self.batch_size:
|
||||
if self.current_archive is None:
|
||||
if self.archive_index >= len(self.archive_paths):
|
||||
if len(images) == 0:
|
||||
raise StopIteration
|
||||
else:
|
||||
break # return the remaining images
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
|
||||
|
||||
current_archive_path = self.archive_paths[self.archive_index]
|
||||
if current_archive_path.endswith(".zip"):
|
||||
self.current_archive = zipfile.ZipFile(current_archive_path)
|
||||
self.files = self.current_archive.namelist()
|
||||
elif current_archive_path.endswith(".tar"):
|
||||
self.current_archive = tarfile.open(current_archive_path, "r")
|
||||
self.files = self.current_archive.getnames()
|
||||
else:
|
||||
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
|
||||
|
||||
self.image_index = 0
|
||||
|
||||
# filter by image extensions
|
||||
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"found {len(self.files)} images in the archive")
|
||||
|
||||
new_images = []
|
||||
while len(images) + len(new_images) < self.batch_size:
|
||||
if self.image_index >= len(self.files):
|
||||
break
|
||||
|
||||
file = self.files[self.image_index]
|
||||
archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
|
||||
self.image_index += 1
|
||||
|
||||
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
|
||||
with archive.open(file) as f:
|
||||
image = Image.open(f).convert("RGB")
|
||||
size = image.size
|
||||
image = preprocess_image(image)
|
||||
return image, size
|
||||
|
||||
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
|
||||
|
||||
# wait for all new_images to load to close the archive
|
||||
new_images = [(image_path, future.result()) for image_path, future in new_images]
|
||||
|
||||
if self.image_index >= len(self.files):
|
||||
self.current_archive.close()
|
||||
self.current_archive = None
|
||||
self.archive_index += 1
|
||||
|
||||
images.extend(new_images)
|
||||
|
||||
return [(image_path, image, size) for image_path, (image, size) in images]
|
||||
|
||||
|
||||
def collate_fn_remove_corrupted(batch):
|
||||
"""Collate function that allows to remove corrupted examples in the
|
||||
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
||||
@@ -460,33 +383,16 @@ def main(args):
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
|
||||
# load metadata if needed
|
||||
metadata = None
|
||||
if args.metadata is not None:
|
||||
if os.path.exists(args.metadata):
|
||||
logger.info(f"loading metadata file: {args.metadata}")
|
||||
with open(args.metadata, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# version check
|
||||
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
|
||||
major, minor, patch = int(major), int(minor), int(patch)
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
logger.warning(
|
||||
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
|
||||
)
|
||||
|
||||
if "images" not in metadata:
|
||||
metadata["images"] = {}
|
||||
else:
|
||||
logger.info(f"metadata file not found: {args.metadata}, creating new metadata")
|
||||
metadata = {"format_version": "1.0.0", "images": {}}
|
||||
|
||||
metadata = tagger_utils.load_metadata(args.metadata)
|
||||
images_metadata = metadata["images"]
|
||||
else:
|
||||
images_metadata = metadata = None
|
||||
|
||||
# prepare DataLoader or something similar :)
|
||||
use_loader = False
|
||||
if args.load_archive:
|
||||
loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size)
|
||||
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
|
||||
use_loader = True
|
||||
elif args.max_data_loader_n_workers is not None:
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
@@ -655,12 +561,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tags_index",
|
||||
type=int,
|
||||
@@ -668,12 +568,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
|
||||
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_archive",
|
||||
action="store_true",
|
||||
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
|
||||
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
|
||||
)
|
||||
tagger_utils.add_archive_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
171
finetune/tagger_utils.py
Normal file
171
finetune/tagger_utils.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Union
|
||||
import zipfile
|
||||
import tarfile
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util
|
||||
|
||||
|
||||
class ArchiveImageLoader:
|
||||
def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
|
||||
self.archive_paths = archive_paths
|
||||
self.batch_size = batch_size
|
||||
self.preprocess = preprocess
|
||||
self.debug = debug
|
||||
self.current_archive = None
|
||||
self.archive_index = 0
|
||||
self.image_index = 0
|
||||
self.files = None
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
images = []
|
||||
while len(images) < self.batch_size:
|
||||
if self.current_archive is None:
|
||||
if self.archive_index >= len(self.archive_paths):
|
||||
if len(images) == 0:
|
||||
raise StopIteration
|
||||
else:
|
||||
break # return the remaining images
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
|
||||
|
||||
current_archive_path = self.archive_paths[self.archive_index]
|
||||
if current_archive_path.endswith(".zip"):
|
||||
self.current_archive = zipfile.ZipFile(current_archive_path)
|
||||
self.files = self.current_archive.namelist()
|
||||
elif current_archive_path.endswith(".tar"):
|
||||
self.current_archive = tarfile.open(current_archive_path, "r")
|
||||
self.files = self.current_archive.getnames()
|
||||
else:
|
||||
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
|
||||
|
||||
self.image_index = 0
|
||||
|
||||
# filter by image extensions
|
||||
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
|
||||
|
||||
if self.debug:
|
||||
logger.info(f"found {len(self.files)} images in the archive")
|
||||
|
||||
new_images = []
|
||||
while len(images) + len(new_images) < self.batch_size:
|
||||
if self.image_index >= len(self.files):
|
||||
break
|
||||
|
||||
file = self.files[self.image_index]
|
||||
archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
|
||||
self.image_index += 1
|
||||
|
||||
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
|
||||
with archive.open(file) as f:
|
||||
image = Image.open(f).convert("RGB")
|
||||
size = image.size
|
||||
image = self.preprocess(image)
|
||||
return image, size
|
||||
|
||||
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
|
||||
|
||||
# wait for all new_images to load to close the archive
|
||||
new_images = [(image_path, future.result()) for image_path, future in new_images]
|
||||
|
||||
if self.image_index >= len(self.files):
|
||||
self.current_archive.close()
|
||||
self.current_archive = None
|
||||
self.archive_index += 1
|
||||
|
||||
images.extend(new_images)
|
||||
|
||||
return [(image_path, image, size) for image_path, (image, size) in images]
|
||||
|
||||
|
||||
class ImageLoader:
|
||||
def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
|
||||
self.image_paths = image_paths
|
||||
self.batch_size = batch_size
|
||||
self.preprocess = preprocess
|
||||
self.debug = debug
|
||||
self.image_index = 0
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
def __len__(self):
|
||||
return math.ceil(len(self.image_paths) / self.batch_size)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.image_index >= len(self.image_paths):
|
||||
raise StopIteration
|
||||
|
||||
images = []
|
||||
while len(images) < self.batch_size and self.image_index < len(self.image_paths):
|
||||
|
||||
def load_image(file):
|
||||
image = Image.open(file).convert("RGB")
|
||||
size = image.size
|
||||
image = self.preprocess(image)
|
||||
return image, size
|
||||
|
||||
image_path = self.image_paths[self.image_index]
|
||||
images.append((image_path, self.executor.submit(load_image, image_path)))
|
||||
self.image_index += 1
|
||||
|
||||
images = [(image_path, future.result()) for image_path, future in images]
|
||||
return [(image_path, image, size) for image_path, (image, size) in images]
|
||||
|
||||
|
||||
def load_metadata(metadata_file: str):
|
||||
if os.path.exists(metadata_file):
|
||||
logger.info(f"loading metadata file: {metadata_file}")
|
||||
with open(metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
# version check
|
||||
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
|
||||
major, minor, patch = int(major), int(minor), int(patch)
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
logger.warning(
|
||||
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
|
||||
)
|
||||
|
||||
if "images" not in metadata:
|
||||
metadata["images"] = {}
|
||||
else:
|
||||
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
|
||||
metadata = {"format_version": "1.0.0", "images": {}}
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def add_archive_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_archive",
|
||||
action="store_true",
|
||||
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
|
||||
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
|
||||
)
|
||||
Reference in New Issue
Block a user