Revert system_prompt for dataset config

This commit is contained in:
rockerBOO
2025-06-16 16:50:18 -04:00
parent 1db78559a6
commit 0e929f97b9

View File

@@ -192,7 +192,7 @@ class ImageInfo:
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_npz: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
None # crop left top right bottom in original pixel size, not latents size
)
self.cond_img_path: Optional[str] = None
@@ -209,8 +209,6 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
self.system_prompt: Optional[str] = None
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -434,7 +432,6 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
self.image_dir = image_dir
@@ -466,7 +463,6 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.system_prompt = system_prompt
self.resize_interpolation = resize_interpolation
@@ -500,7 +496,6 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -529,15 +524,14 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
# if self.caption_extension and not self.caption_extension.startswith("."):
# self.caption_extension = "." + self.caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
@@ -573,7 +567,6 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -602,7 +595,6 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
@@ -642,7 +634,6 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -671,7 +662,6 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt,
resize_interpolation=resize_interpolation,
)
@@ -1713,10 +1703,8 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs_list.append(text_encoder_outputs)
if tokenization_required:
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else ""
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
@@ -1886,8 +1874,7 @@ class DreamBoothDataset(BaseDataset):
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -1900,7 +1887,6 @@ class DreamBoothDataset(BaseDataset):
self.is_training_dataset = is_training_dataset
self.validation_seed = validation_seed
self.validation_split = validation_split
self.system_prompt = system_prompt
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1917,33 +1903,30 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
def read_caption(img_path, caption_extension, enable_wildcard):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)]
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
for base, cap_extension in cap_paths:
# check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt)
for cap_path in [base + cap_extension, base + "." + cap_extension]:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
break
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
return caption
def load_dreambooth_dir(subset: DreamBoothSubset):
@@ -2090,7 +2073,6 @@ class DreamBoothDataset(BaseDataset):
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
num_repeats = subset.num_repeats if self.is_training_dataset else 1
if num_repeats < 1:
@@ -2117,10 +2099,8 @@ class DreamBoothDataset(BaseDataset):
else:
num_train_images += num_repeats * len(img_paths)
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
@@ -2177,8 +2157,7 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
system_prompt: Optional[str] = None,
resize_interpolation: Optional[str] = None,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2406,8 +2385,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
system_prompt: Optional[str] = None,
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
@@ -2461,7 +2439,6 @@ class ControlNetDataset(BaseDataset):
debug_dataset,
validation_split,
validation_seed,
system_prompt,
resize_interpolation,
)
@@ -3005,7 +2982,7 @@ def trim_and_resize_if_required(
# for new_cache_latents
def load_images_and_masks_for_caching(
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
r"""
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
@@ -6241,6 +6218,7 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)