Fix samples, LoRA training. Add system prompt, use_flash_attn

This commit is contained in:
rockerBOO
2025-02-23 01:29:18 -05:00
parent 6597631b90
commit 025cca699b
10 changed files with 888 additions and 386 deletions

View File

@@ -195,7 +195,7 @@ class ImageInfo:
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_npz: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
None # crop left top right bottom in original pixel size, not latents size
)
self.cond_img_path: Optional[str] = None
@@ -211,6 +211,8 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.system_prompt: Optional[str] = None
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -434,6 +436,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -464,6 +467,8 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.system_prompt = system_prompt
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.is_reg = is_reg
@@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.metadata_file = metadata_file
@@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
system_prompt: Optional[str] = None
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
system_prompt=system_prompt
)
self.conditioning_data_dir = conditioning_data_dir
@@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs_list.append(text_encoder_outputs)
if tokenization_required:
system_prompt = subset.system_prompt or ""
caption = self.process_caption(subset, image_info.caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
@@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images = 0
num_reg_images = 0
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
num_repeats = subset.num_repeats if self.is_training_dataset else 1
if num_repeats < 1:
@@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2967,7 +2980,7 @@ def trim_and_resize_if_required(
# for new_cache_latents
def load_images_and_masks_for_caching(
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
r"""
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size