mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Update embedder_dims, add more flexible caption extension
This commit is contained in:
@@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
self.is_reg = is_reg
|
||||
self.class_tokens = class_tokens
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
# if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
# self.caption_extension = "." + self.caption_extension
|
||||
self.cache_info = cache_info
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
@@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
    """Read the caption stored in a sidecar file next to ``img_path``.

    Candidate caption files are derived from the image path with its
    extension removed. Two base names are tried, in this order:

    1. the image path without its extension;
    2. that same path with its last four ``_``-separated tokens dropped
       (only when it has at least five tokens). The local name suggests
       this targets face-detection crops -- presumably
       ``name_x_y_w_h``-style file names; confirm against the caller.

    For each base name, ``caption_extension`` is appended both as given and
    with a ``.`` inserted in front, so ``".txt"``, ``"txt"`` and suffix-style
    values such as ``"_var.txt"`` all resolve. The first candidate that
    exists on disk wins.

    Args:
        img_path: Path of the image whose caption is wanted.
        caption_extension: Caption file extension or suffix, with or without
            a leading dot.
        enable_wildcard: When true, return every non-blank line (stripped)
            joined with ``"\n"``; otherwise return only the stripped first
            line.

    Returns:
        The caption text, or ``None`` when no candidate file exists.

    Raises:
        UnicodeDecodeError: The caption file is not valid UTF-8 (logged
            before being re-raised).
        AssertionError: The caption file exists but contains no lines.
    """
    # Build the candidate base names for the caption file.
    base_name = os.path.splitext(img_path)[0]
    base_name_face_det = base_name
    tokens = base_name.split("_")
    if len(tokens) >= 5:
        base_name_face_det = "_".join(tokens[:-4])

    # Flatten (base, extension variant) into one ordered candidate list.
    # Duplicates are dropped so that the common case of both base names being
    # identical does not hit the filesystem twice for the same path.
    candidates = []
    for base in (base_name, base_name_face_det):
        # Try with and without "." to allow extension flexibility
        # (img_var.txt, img.txt, img + txt).
        for cap_path in (base + caption_extension, base + "." + caption_extension):
            if cap_path not in candidates:
                candidates.append(cap_path)

    for cap_path in candidates:
        if not os.path.isfile(cap_path):
            continue

        with open(cap_path, "rt", encoding="utf-8") as f:
            try:
                lines = f.readlines()
            except UnicodeDecodeError as e:
                logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
                raise e

        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"

        # Returning here makes "first existing candidate wins" explicit and
        # guarantees the fallback base name is still reached when the primary
        # one has no caption file.
        if enable_wildcard:
            # Drop blank lines and join the rest with newlines.
            return "\n".join([line.strip() for line in lines if line.strip() != ""])
        return lines[0].strip()

    return None
|
||||
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
|
||||
Reference in New Issue
Block a user