Update embedder_dims, add more flexible caption extension

This commit is contained in:
rockerBOO
2025-03-04 02:21:05 -05:00
parent 5e45df722d
commit 1f22a94cfe
3 changed files with 159 additions and 121 deletions

View File

@@ -887,6 +887,9 @@ class NextDiT(nn.Module):
),
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
nn.init.zeros_(self.cap_embedder[1].bias)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
@@ -929,9 +932,6 @@ class NextDiT(nn.Module):
]
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
# nn.init.zeros_(self.cap_embedder[1].weight)
nn.init.zeros_(self.cap_embedder[1].bias)
self.layers = nn.ModuleList(
[

View File

@@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
# if self.caption_extension and not self.caption_extension.startswith("."):
# self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
@@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension, enable_wildcard):
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
for base, cap_extension in cap_paths:
# check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt)
for cap_path in [base + cap_extension, base + "." + cap_extension]:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
break
return caption
def load_dreambooth_dir(subset: DreamBoothSubset):