mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Revert system_prompt for dataset config
This commit is contained in:
@@ -192,7 +192,7 @@ class ImageInfo:
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||
None # crop left top right bottom in original pixel size, not latents size
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
@@ -209,8 +209,6 @@ class ImageInfo:
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.resize_interpolation: Optional[str] = None
|
||||
|
||||
self.system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
@@ -434,7 +432,6 @@ class BaseSubset:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
@@ -466,7 +463,6 @@ class BaseSubset:
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.system_prompt = system_prompt
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
|
||||
@@ -500,7 +496,6 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
@@ -529,15 +524,14 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
self.class_tokens = class_tokens
|
||||
self.caption_extension = caption_extension
|
||||
# if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
# self.caption_extension = "." + self.caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
self.cache_info = cache_info
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
@@ -573,7 +567,6 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
@@ -602,7 +595,6 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
@@ -642,7 +634,6 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
@@ -671,7 +662,6 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
system_prompt=system_prompt,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
@@ -1713,10 +1703,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else ""
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
@@ -1886,8 +1874,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
@@ -1900,7 +1887,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.is_training_dataset = is_training_dataset
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
@@ -1917,33 +1903,30 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
|
||||
def read_caption(img_path, caption_extension, enable_wildcard):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)]
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
for base, cap_extension in cap_paths:
|
||||
# check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt)
|
||||
for cap_path in [base + cap_extension, base + "." + cap_extension]:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
break
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
@@ -2090,7 +2073,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images = 0
|
||||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
|
||||
for subset in subsets:
|
||||
num_repeats = subset.num_repeats if self.is_training_dataset else 1
|
||||
if num_repeats < 1:
|
||||
@@ -2117,10 +2099,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
else:
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else ""
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
@@ -2177,8 +2157,7 @@ class FineTuningDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
system_prompt: Optional[str] = None,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
@@ -2406,8 +2385,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
system_prompt: Optional[str] = None,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
@@ -2461,7 +2439,6 @@ class ControlNetDataset(BaseDataset):
|
||||
debug_dataset,
|
||||
validation_split,
|
||||
validation_seed,
|
||||
system_prompt,
|
||||
resize_interpolation,
|
||||
)
|
||||
|
||||
@@ -3005,7 +2982,7 @@ def trim_and_resize_if_required(
|
||||
# for new_cache_latents
|
||||
def load_images_and_masks_for_caching(
|
||||
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
r"""
|
||||
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
|
||||
|
||||
@@ -6241,6 +6218,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["renorm_cfg"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(ex)
|
||||
|
||||
Reference in New Issue
Block a user