fix to work

This commit is contained in:
Kohya S
2024-11-29 21:59:25 +09:00
parent 2238b94e7b
commit 744cf03136

View File

@@ -125,6 +125,7 @@ class ArchiveImageLoader:
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")
new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
@@ -166,6 +167,10 @@ def collate_fn_remove_corrupted(batch):
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
@@ -436,7 +441,7 @@ def main(args):
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": [image_size.width, image_size.height]}
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
@@ -464,6 +469,7 @@ def main(args):
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > 1 or (major == 1 and minor > 0):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
@@ -480,7 +486,7 @@ def main(args):
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = ArchiveImageLoader(image_paths, args.batch_size)
loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション