Merge branch 'dev' into dev_device_support

This commit is contained in:
Kohya S
2024-02-12 13:01:54 +09:00
62 changed files with 1387 additions and 993 deletions

View File

@@ -14,8 +14,12 @@ from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
@@ -38,8 +42,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps
@@ -73,16 +77,16 @@ def main(args):
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")
# captioningする
def run_batch(path_imgs):
@@ -100,7 +104,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -129,7 +133,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -140,7 +144,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser: