feat: Florence-2 captioning (WIP)

This commit is contained in:
Kohya S
2024-12-05 22:04:37 +09:00
parent b72b9eaf11
commit 28e9352cc5
3 changed files with 415 additions and 116 deletions

View File

@@ -0,0 +1,232 @@
# add caption to images by Florence-2
import argparse
import json
import os
import glob
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
from library import device_utils, train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import tagger_utils
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
def main(args):
    """Caption images (or images inside archives) with Florence-2-large.

    The model and processor are loaded from ``microsoft/Florence-2-large``,
    every image is run through the ``TASK_PROMPT`` task, and the resulting
    caption is either written next to the image (``args.caption_extension``)
    or stored in the metadata JSON file given by ``args.metadata``.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``load_archive``, ``metadata``, ``device``, ``train_data_dir``,
            ``recursive``, ``batch_size``, ``debug``, ``max_new_tokens``,
            ``num_beams``, ``remove_mood``, ``caption_extension`` and
            ``caption_index``.

    Raises:
        AssertionError: if ``load_archive`` and ``metadata`` are not given
            together (or both omitted).
    """
    # The check is an equality: both options set, or neither.
    # NOTE(review): the message only mentions one direction, and asserts are
    # stripped under ``python -O`` -- confirm whether metadata without
    # load_archive is meant to be rejected.
    assert args.load_archive == (
        args.metadata is not None
    ), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"

    device = args.device if args.device is not None else device_utils.get_preferred_device()
    if isinstance(device, str):
        device = torch.device(device)
    # half precision only on CUDA; everything else runs in float32
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    logger.info(f"device: {device}, dtype: {torch_dtype}")

    logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")

    def load_model():
        # single place for the checkpoint name and loading options
        return AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
        ).to(device)

    try:
        import flash_attn  # noqa: F401 -- imported only to probe availability

        support_flash_attn = True
    except ImportError:
        support_flash_attn = False

    if support_flash_attn:
        model = load_model()
    else:
        logger.info(
            "flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
        )

        # https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
        # Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
        from transformers.dynamic_module_utils import get_imports
        from unittest.mock import patch

        def fixed_get_imports(filename) -> list[str]:
            imports = get_imports(filename)
            # Only the Florence-2 modeling file is touched, and only when it
            # actually lists flash_attn (list.remove would raise otherwise).
            if str(filename).endswith("modeling_florence2.py") and "flash_attn" in imports:
                imports.remove("flash_attn")
            return imports

        # workaround for unnecessary flash_attn requirement
        with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
            model = load_model()

    model.eval()
    processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

    # collect the image files, or the archives that contain them
    if not args.load_archive:
        train_data_dir_path = Path(args.train_data_dir)
        image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
        logger.info(f"found {len(image_paths)} images.")
    else:
        archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
            os.path.join(args.train_data_dir, "*.tar")
        )
        image_paths = [Path(archive_file) for archive_file in archive_files]

    # load metadata if needed
    if args.metadata is not None:
        metadata = tagger_utils.load_metadata(args.metadata)
        images_metadata = metadata["images"]
    else:
        images_metadata = metadata = None

    # define preprocess_image function
    def preprocess_image(image: Image.Image):
        # the processor output is moved to the model device/dtype right away
        inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
        return inputs

    # prepare DataLoader or something similar :)
    # Loader returns: list of (image_path, processed_image_or_something, image_size)
    if args.load_archive:
        loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
    else:
        # we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
        loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)

    def run_batch(
        list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
        images_metadata: Optional[dict[str, Any]],
        caption_index: Optional[int] = None,
    ):
        """Generate captions for one batch and store them (file or metadata)."""
        # every item was preprocessed separately, so concatenate along dim 0
        input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
        pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])

        if args.debug:
            logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                max_new_tokens=args.max_new_tokens,
                num_beams=args.num_beams,
            )
        if args.debug:
            logger.info(f"generate done: {generated_ids.shape}")

        # special tokens are kept here and handed to post_process_generation
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
        if args.debug:
            logger.info(f"decode done: {len(generated_texts)}")

        for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
            parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
            caption_text = parsed_answer[TASK_PROMPT]

            # drop the padding tokens first so that whitespace left in front
            # of them is trimmed as well
            caption_text = caption_text.replace("<pad>", "").strip()
            original_caption_text = caption_text

            if args.remove_mood:
                # cut everything from the first "The overall " onwards
                p = caption_text.find("The overall ")
                if p != -1:
                    caption_text = caption_text[:p].strip()

            if images_metadata is None:
                caption_file = os.path.splitext(image_path)[0] + args.caption_extension
                with open(caption_file, "wt", encoding="utf-8") as f:
                    f.write(caption_text + "\n")
            else:
                image_md = images_metadata.get(image_path, None)
                if image_md is None:
                    image_md = {"image_size": list(image_size)}
                    images_metadata[image_path] = image_md
                if "caption" not in image_md:
                    image_md["caption"] = []
                if caption_index is None:
                    image_md["caption"].append(caption_text)
                else:
                    # pad with empty captions so the requested index exists
                    while len(image_md["caption"]) <= caption_index:
                        image_md["caption"].append("")
                    image_md["caption"][caption_index] = caption_text

            if args.debug:
                logger.info("")
                logger.info(f"{image_path}:")
                logger.info(f"\tCaption: {caption_text}")
                if args.remove_mood and original_caption_text != caption_text:
                    logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")

    for data_entry in tqdm(loader, smoothing=0.0):
        # Convert image_path to string
        b_imgs = [(str(image_path), image, size) for image_path, image, size in data_entry]
        run_batch(b_imgs, images_metadata, args.caption_index)

    if args.metadata is not None:
        logger.info(f"saving metadata file: {args.metadata}")
        with open(args.metadata, "wt", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

    logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the Florence-2 captioning script.

    The script-specific options are registered from a table, in the order
    they appear in the help output; the shared ``--metadata`` /
    ``--load_archive`` options are appended by
    ``tagger_utils.add_archive_arguments``.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser()

    # (argument name, keyword options) -- registration order is kept
    argument_table = [
        ("train_data_dir", {"type": str, "help": "directory for train images / 学習画像データのディレクトリ"}),
        ("--batch_size", {"type": int, "default": 1, "help": "batch size in inference / 推論時のバッチサイズ"}),
        (
            "--caption_extension",
            {"type": str, "default": ".txt", "help": "extension of caption file / 出力されるキャプションファイルの拡張子"},
        ),
        ("--recursive", {"action": "store_true", "help": "search images recursively / 画像を再帰的に検索する"}),
        (
            "--remove_mood",
            {"action": "store_true", "help": "remove mood from the caption / キャプションからムードを削除する"},
        ),
        (
            "--max_new_tokens",
            {
                "type": int,
                "default": 1024,
                "help": "maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
            },
        ),
        (
            "--num_beams",
            {
                "type": int,
                "default": 3,
                "help": "number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
            },
        ),
        (
            "--device",
            {
                "type": str,
                "default": None,
                "help": "device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
            },
        ),
        (
            "--caption_index",
            {
                "type": int,
                "default": None,
                "help": "index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
                " / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
            },
        ),
        ("--debug", {"action": "store_true", "help": "debug mode"}),
    ]
    for argument_name, options in argument_table:
        cli.add_argument(argument_name, **options)

    tagger_utils.add_archive_arguments(cli)
    return cli
# Script entry point: parse the command line and run the captioning job.
if __name__ == "__main__":
    main(setup_parser().parse_args())

View File

@@ -1,13 +1,10 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import csv
import glob
import json
import os
from pathlib import Path
from typing import Any, Optional, Union
import zipfile
import tarfile
from typing import Any, Optional
import cv2
import numpy as np
@@ -16,14 +13,17 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
from library.utils import pil_resize
import tagger_utils
# from wd14 tagger
IMAGE_SIZE = 448
@@ -79,83 +79,6 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
return (image, img_path, size)
class ArchiveImageLoader:
def __init__(self, archive_paths: list[str], batch_size: int, debug: bool = False):
self.archive_paths = archive_paths
self.batch_size = batch_size
self.debug = debug
self.current_archive = None
self.archive_index = 0
self.image_index = 0
self.files = None
self.executor = ThreadPoolExecutor()
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
def __iter__(self):
return self
def __next__(self):
images = []
while len(images) < self.batch_size:
if self.current_archive is None:
if self.archive_index >= len(self.archive_paths):
if len(images) == 0:
raise StopIteration
else:
break # return the remaining images
if self.debug:
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
current_archive_path = self.archive_paths[self.archive_index]
if current_archive_path.endswith(".zip"):
self.current_archive = zipfile.ZipFile(current_archive_path)
self.files = self.current_archive.namelist()
elif current_archive_path.endswith(".tar"):
self.current_archive = tarfile.open(current_archive_path, "r")
self.files = self.current_archive.getnames()
else:
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
self.image_index = 0
# filter by image extensions
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")
new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
file = self.files[self.image_index]
archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
self.image_index += 1
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
with archive.open(file) as f:
image = Image.open(f).convert("RGB")
size = image.size
image = preprocess_image(image)
return image, size
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
# wait for all new_images to load to close the archive
new_images = [(image_path, future.result()) for image_path, future in new_images]
if self.image_index >= len(self.files):
self.current_archive.close()
self.current_archive = None
self.archive_index += 1
images.extend(new_images)
return [(image_path, image, size) for image_path, (image, size) in images]
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
@@ -460,33 +383,16 @@ def main(args):
logger.info(f"\tGeneral tags: {general_tag_text}")
# load metadata if needed
metadata = None
if args.metadata is not None:
if os.path.exists(args.metadata):
logger.info(f"loading metadata file: {args.metadata}")
with open(args.metadata, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > 1 or (major == 1 and minor > 0):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
logger.info(f"metadata file not found: {args.metadata}, creating new metadata")
metadata = {"format_version": "1.0.0", "images": {}}
metadata = tagger_utils.load_metadata(args.metadata)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size)
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -655,12 +561,6 @@ def setup_parser() -> argparse.ArgumentParser:
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
)
parser.add_argument(
"--tags_index",
type=int,
@@ -668,12 +568,8 @@ def setup_parser() -> argparse.ArgumentParser:
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
)
parser.add_argument(
"--load_archive",
action="store_true",
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
)
tagger_utils.add_archive_arguments(parser)
return parser

171
finetune/tagger_utils.py Normal file
View File

@@ -0,0 +1,171 @@
import argparse
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Union
import zipfile
import tarfile
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import train_util
class ArchiveImageLoader:
    """Single-pass batch iterator over the images stored in archives.

    The archives are processed one after another. Every ``next()`` call
    returns a list of at most ``batch_size`` tuples
    ``(archive_and_image_path, preprocessed, size)`` where
    ``archive_and_image_path`` is ``"<archive path>////<member name>"``,
    ``preprocessed`` is whatever ``preprocess`` returned for the RGB image
    and ``size`` is the ``size`` attribute of that image read before
    preprocessing. A batch may span two archives.

    Args:
        archive_paths: paths of ``.zip`` / ``.tar`` files to read.
        batch_size: maximum number of images per batch.
        preprocess: callable applied to each decoded RGB ``PIL.Image``.
        debug: log the archive being opened and its image count.

    Raises:
        ValueError: from ``next()`` when an archive path ends with neither
            ``.zip`` nor ``.tar``.
    """

    def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
        self.archive_paths = archive_paths
        self.batch_size = batch_size
        self.preprocess = preprocess
        self.debug = debug
        self.current_archive = None  # archive currently open, if any
        self.archive_index = 0  # index into archive_paths
        self.image_index = 0  # index into self.files for the open archive
        self.files = None  # image member names of the open archive
        self.executor = ThreadPoolExecutor()
        self.image_exts = set(train_util.IMAGE_EXTENSIONS)

    def __iter__(self):
        return self

    def __next__(self):
        images = []
        while len(images) < self.batch_size:
            if self.current_archive is None:
                if self.archive_index >= len(self.archive_paths):
                    if not images:
                        raise StopIteration
                    break  # return the remaining images
                self._open_current_archive()

            pending = []
            while len(images) + len(pending) < self.batch_size and self.image_index < len(self.files):
                member_name = self.files[self.image_index]
                self.image_index += 1
                archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{member_name}"
                pending.append((archive_and_image_path, self._submit_load(member_name)))

            # wait for all pending images to load before the archive may be closed
            loaded = [(image_path, future.result()) for image_path, future in pending]

            if self.image_index >= len(self.files):
                self.current_archive.close()
                self.current_archive = None
                self.archive_index += 1

            images.extend(loaded)

        return [(image_path, image, size) for image_path, (image, size) in images]

    def _open_current_archive(self):
        """Open ``archive_paths[archive_index]`` and list its image members."""
        current_archive_path = self.archive_paths[self.archive_index]
        if self.debug:
            logger.info(f"loading archive: {current_archive_path}")

        if current_archive_path.endswith(".zip"):
            self.current_archive = zipfile.ZipFile(current_archive_path)
            member_names = self.current_archive.namelist()
        elif current_archive_path.endswith(".tar"):
            self.current_archive = tarfile.open(current_archive_path, "r")
            # only regular files can be extracted as image data
            member_names = [member.name for member in self.current_archive.getmembers() if member.isfile()]
        else:
            raise ValueError(f"unsupported archive file: {current_archive_path}")

        self.image_index = 0

        # filter by image extensions
        self.files = [name for name in member_names if os.path.splitext(name)[1].lower() in self.image_exts]
        if self.debug:
            logger.info(f"found {len(self.files)} images in the archive")

    def _submit_load(self, member_name: str):
        """Schedule decoding + preprocessing of one member; returns a future."""
        archive = self.current_archive
        if isinstance(archive, tarfile.TarFile):
            import io

            # TarFile has no per-member ``open``; ``TarFile.open`` is the
            # classmethod that opens a whole archive. Members are read with
            # ``extractfile``. The bytes are read here, in the calling thread,
            # so pool workers never read from the tar file concurrently.
            with archive.extractfile(member_name) as member:
                data = member.read()
            return self.executor.submit(self._decode, io.BytesIO(data))
        return self.executor.submit(self._load_zip_member, archive, member_name)

    def _load_zip_member(self, archive: zipfile.ZipFile, member_name: str):
        """Open a zip member and decode it (runs in a pool worker)."""
        with archive.open(member_name) as f:
            return self._decode(f)

    def _decode(self, fileobj):
        """Decode ``fileobj`` as RGB and return ``(preprocessed, size)``."""
        image = Image.open(fileobj).convert("RGB")
        size = image.size
        return self.preprocess(image), size
class ImageLoader:
    """Single-pass batch iterator over image files.

    Each ``next()`` call submits up to ``batch_size`` images to a thread
    pool, waits for all of them and returns a list of
    ``(image_path, preprocessed, size)`` tuples, where ``preprocessed`` is the
    result of ``preprocess`` on the RGB image and ``size`` is the image's
    ``size`` attribute taken before preprocessing.
    """

    def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.preprocess = preprocess
        self.debug = debug
        self.image_index = 0  # position of the next image to load
        self.executor = ThreadPoolExecutor()

    def __len__(self):
        # number of batches, the last one possibly being partial
        image_count = len(self.image_paths)
        return math.ceil(image_count / self.batch_size)

    def __iter__(self):
        return self

    def __next__(self):
        total = len(self.image_paths)
        if self.image_index >= total:
            raise StopIteration

        def decode(path):
            rgb = Image.open(path).convert("RGB")
            dims = rgb.size
            return self.preprocess(rgb), dims

        # schedule the whole batch first, then collect the results in order
        pending = []
        while len(pending) < self.batch_size and self.image_index < total:
            path = self.image_paths[self.image_index]
            pending.append((path, self.executor.submit(decode, path)))
            self.image_index += 1

        batch = []
        for path, future in pending:
            processed, dims = future.result()
            batch.append((path, processed, dims))
        return batch
def load_metadata(metadata_file: str):
    """Load the dataset metadata JSON, or create an empty structure.

    When ``metadata_file`` does not exist, a fresh
    ``{"format_version": "1.0.0", "images": {}}`` dict is returned. Otherwise
    the file is parsed, a warning is logged if its ``format_version`` is
    newer than 1.0.x, and an ``"images"`` entry is added when missing.

    Args:
        metadata_file: path of the metadata JSON file.

    Returns:
        The metadata dict; it always contains an ``"images"`` key.
    """
    if not os.path.exists(metadata_file):
        logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
        return {"format_version": "1.0.0", "images": {}}

    logger.info(f"loading metadata file: {metadata_file}")
    with open(metadata_file, "rt", encoding="utf-8") as f:
        metadata = json.load(f)

    # version check ("0.0.0" is assumed when the file carries no version)
    version_text = metadata.get("format_version", "0.0.0")
    major_text, minor_text, patch_text = version_text.split(".")
    major, minor, patch = int(major_text), int(minor_text), int(patch_text)
    if (major, minor) > (1, 0):
        logger.warning(
            f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
        )

    metadata.setdefault("images", {})
    return metadata
def add_archive_arguments(parser: argparse.ArgumentParser):
    """Register the ``--metadata`` and ``--load_archive`` options on ``parser``.

    Args:
        parser: the parser to extend in place.
    """
    # (option flag, keyword options) in registration order
    option_table = (
        (
            "--metadata",
            {
                "type": str,
                "default": None,
                "help": "metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
            },
        ),
        (
            "--load_archive",
            {
                "action": "store_true",
                "help": "load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
                " / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
            },
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)