diff --git a/library/train_util.py b/library/train_util.py index 68019e21..1d80bcd8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -192,7 +192,7 @@ class ImageInfo: self.latents_flipped: Optional[torch.Tensor] = None self.latents_npz: Optional[str] = None # set in cache_latents self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size - self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = ( + self.latents_crop_ltrb: Optional[Tuple[int, int]] = ( None # crop left top right bottom in original pixel size, not latents size ) self.cond_img_path: Optional[str] = None @@ -209,8 +209,6 @@ class ImageInfo: self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime self.resize_interpolation: Optional[str] = None - self.system_prompt: Optional[str] = None - class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -434,7 +432,6 @@ class BaseSubset: custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: self.image_dir = image_dir @@ -466,7 +463,6 @@ class BaseSubset: self.validation_seed = validation_seed self.validation_split = validation_split - self.system_prompt = system_prompt self.resize_interpolation = resize_interpolation @@ -500,7 +496,6 @@ class DreamBoothSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -529,15 +524,14 @@ class DreamBoothSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension - # if self.caption_extension and not self.caption_extension.startswith("."): - # self.caption_extension = "." + self.caption_extension + if self.caption_extension and not self.caption_extension.startswith("."): + self.caption_extension = "." + self.caption_extension self.cache_info = cache_info def __eq__(self, other) -> bool: @@ -573,7 +567,6 @@ class FineTuningSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -602,7 +595,6 @@ class FineTuningSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -642,7 +634,6 @@ class ControlNetSubset(BaseSubset): custom_attributes: Optional[Dict[str, Any]] = None, validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, - system_prompt: Optional[str] = None, resize_interpolation: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -671,7 +662,6 @@ class ControlNetSubset(BaseSubset): custom_attributes=custom_attributes, validation_seed=validation_seed, validation_split=validation_split, - system_prompt=system_prompt, resize_interpolation=resize_interpolation, ) @@ -1713,10 +1703,8 @@ class BaseDataset(torch.utils.data.Dataset): text_encoder_outputs_list.append(text_encoder_outputs) if tokenization_required: - system_prompt_special_token = "" - system_prompt = f"{subset.system_prompt} {system_prompt_special_token} " if subset.system_prompt else "" caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension # if self.XTI_layers: # caption_layer = [] # for layer in self.XTI_layers: @@ -1886,8 +1874,7 @@ class DreamBoothDataset(BaseDataset): debug_dataset: bool, validation_split: float, validation_seed: Optional[int], - system_prompt: Optional[str] = None, - resize_interpolation: Optional[str] = None, + resize_interpolation: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -1900,7 +1887,6 @@ class DreamBoothDataset(BaseDataset): self.is_training_dataset = is_training_dataset self.validation_seed = validation_seed self.validation_split = validation_split - self.system_prompt = system_prompt self.enable_bucket = enable_bucket if self.enable_bucket: @@ -1917,33 +1903,30 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool): + def read_caption(img_path, caption_extension, enable_wildcard): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name tokens = base_name.split("_") if len(tokens) >= 5: base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)] + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] caption = None - for base, cap_extension in cap_paths: - # check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt) - for cap_path in [base + cap_extension, base + "." + cap_extension]: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - if enable_wildcard: - caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 - else: - caption = lines[0].strip() - break - break + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() + break return caption def load_dreambooth_dir(subset: DreamBoothSubset): @@ -2090,7 +2073,6 @@ class DreamBoothDataset(BaseDataset): num_train_images = 0 num_reg_images = 0 reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = [] - for subset in subsets: num_repeats = subset.num_repeats if self.is_training_dataset else 1 if num_repeats < 1: @@ -2117,10 +2099,8 @@ class DreamBoothDataset(BaseDataset): else: num_train_images += num_repeats * len(img_paths) - system_prompt_special_token = "" - system_prompt = f"{self.system_prompt or subset.system_prompt} {system_prompt_special_token} " if self.system_prompt or subset.system_prompt else "" for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation if size is not None: info.image_size = size @@ -2177,8 +2157,7 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, - system_prompt: Optional[str] = None, - resize_interpolation: Optional[str] = None, + resize_interpolation: Optional[str], ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2406,8 +2385,7 @@ class ControlNetDataset(BaseDataset): bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], - system_prompt: Optional[str] = None, + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2461,7 +2439,6 @@ class ControlNetDataset(BaseDataset): debug_dataset, validation_split, validation_seed, - system_prompt, resize_interpolation, ) @@ -3005,7 +2982,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size @@ -6241,6 +6218,7 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["renorm_cfg"] = float(m.group(1)) continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex)