mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Fix samples, LoRA training. Add system prompt, use_flash_attn
@@ -195,7 +195,7 @@ class ImageInfo:
         self.latents_flipped: Optional[torch.Tensor] = None
         self.latents_npz: Optional[str] = None  # set in cache_latents
         self.latents_original_size: Optional[Tuple[int, int]] = None  # original image size, not latents size
-        self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
+        self.latents_crop_ltrb: Optional[Tuple[int, int, int, int]] = (
             None  # crop left top right bottom in original pixel size, not latents size
         )
         self.cond_img_path: Optional[str] = None
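The annotation fix above widens latents_crop_ltrb from two to four components: (left, top, right, bottom) in original pixel coordinates, as the inline comment already stated. A minimal sketch of consuming such a tuple with Pillow; apply_crop is a hypothetical helper for illustration, not code from this repo:

    from typing import Optional, Tuple

    from PIL import Image

    def apply_crop(img: Image.Image, crop_ltrb: Optional[Tuple[int, int, int, int]]) -> Image.Image:
        # (left, top, right, bottom) in original pixel coordinates; None means
        # no crop was recorded for this image.
        if crop_ltrb is None:
            return img
        return img.crop(crop_ltrb)  # PIL's crop() takes exactly this (l, t, r, b) box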
@@ -211,6 +211,8 @@ class ImageInfo:
 
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
+        self.system_prompt: Optional[str] = None
+
 
 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -434,6 +436,7 @@ class BaseSubset:
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        system_prompt: Optional[str] = None
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -464,6 +467,8 @@ class BaseSubset:
         self.validation_seed = validation_seed
         self.validation_split = validation_split
+
+        self.system_prompt = system_prompt
 
 
 class DreamBoothSubset(BaseSubset):
     def __init__(
@@ -495,6 +500,7 @@ class DreamBoothSubset(BaseSubset):
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        system_prompt: Optional[str] = None
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -522,6 +528,7 @@ class DreamBoothSubset(BaseSubset):
             custom_attributes=custom_attributes,
             validation_seed=validation_seed,
             validation_split=validation_split,
+            system_prompt=system_prompt
         )
 
         self.is_reg = is_reg
@@ -564,6 +571,7 @@ class FineTuningSubset(BaseSubset):
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        system_prompt: Optional[str] = None
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
@@ -591,6 +599,7 @@ class FineTuningSubset(BaseSubset):
             custom_attributes=custom_attributes,
             validation_seed=validation_seed,
             validation_split=validation_split,
+            system_prompt=system_prompt
         )
 
         self.metadata_file = metadata_file
@@ -629,6 +638,7 @@ class ControlNetSubset(BaseSubset):
         custom_attributes: Optional[Dict[str, Any]] = None,
         validation_seed: Optional[int] = None,
         validation_split: Optional[float] = 0.0,
+        system_prompt: Optional[str] = None
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -656,6 +666,7 @@ class ControlNetSubset(BaseSubset):
             custom_attributes=custom_attributes,
             validation_seed=validation_seed,
             validation_split=validation_split,
+            system_prompt=system_prompt
         )
 
         self.conditioning_data_dir = conditioning_data_dir
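All three subset types accept the new option and forward it unchanged to BaseSubset, which stores it. A reduced sketch of that forwarding pattern; the class names and the remaining parameters are stand-ins for the much longer real signatures:

    from typing import Optional

    class BaseSubsetSketch:
        def __init__(self, system_prompt: Optional[str] = None) -> None:
            self.system_prompt = system_prompt  # None when no prompt is configured

    class DreamBoothSubsetSketch(BaseSubsetSketch):
        def __init__(self, is_reg: bool = False, system_prompt: Optional[str] = None) -> None:
            super().__init__(system_prompt=system_prompt)  # forwarded verbatim
            self.is_reg = is_reg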
@@ -1686,8 +1697,9 @@ class BaseDataset(torch.utils.data.Dataset):
                     text_encoder_outputs_list.append(text_encoder_outputs)
 
                 if tokenization_required:
+                    system_prompt = subset.system_prompt or ""
                     caption = self.process_caption(subset, image_info.caption)
-                    input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)]  # remove batch dimension
+                    input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(system_prompt + caption)]  # remove batch dimension
                     # if self.XTI_layers:
                     #     caption_layer = []
                     #     for layer in self.XTI_layers:
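At tokenization time, the subset's system prompt (empty string when unset) is prepended to the processed caption before input IDs are produced. A self-contained sketch of the concatenation rule; tokenize here is a toy stand-in for self.tokenize_strategy.tokenize(), which in the real code returns batched tensors whose batch dimension is stripped with ids[0]:

    from typing import List, Optional

    def tokenize(text: str) -> List[int]:
        return [ord(c) for c in text]  # toy tokenizer: one "token" per character

    def input_ids_for(system_prompt: Optional[str], caption: str) -> List[int]:
        return tokenize((system_prompt or "") + caption)

    assert input_ids_for(None, "a photo") == tokenize("a photo")  # no prompt: caption unchanged
    assert input_ids_for("SYSTEM ", "a photo") == tokenize("SYSTEM a photo")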
@@ -2059,6 +2071,7 @@ class DreamBoothDataset(BaseDataset):
         num_train_images = 0
         num_reg_images = 0
+        reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
 
         for subset in subsets:
             num_repeats = subset.num_repeats if self.is_training_dataset else 1
             if num_repeats < 1:
@@ -2086,7 +2099,7 @@ class DreamBoothDataset(BaseDataset):
             num_train_images += num_repeats * len(img_paths)
 
             for img_path, caption, size in zip(img_paths, captions, sizes):
-                info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
+                info = ImageInfo(img_path, num_repeats, (subset.system_prompt or "") + caption, subset.is_reg, img_path)
                 if size is not None:
                     info.image_size = size
                 if subset.is_reg:
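The parentheses around the fallback matter here: Python's "or" binds more loosely than "+", so the unparenthesized form subset.system_prompt or "" + caption would parse as subset.system_prompt or ("" + caption) and silently drop the caption whenever a system prompt is set. A quick illustration:

    sp, caption = "SYSTEM ", "a photo"
    assert (sp or "") + caption == "SYSTEM a photo"  # intended: prompt, then caption
    assert (sp or "" + caption) == "SYSTEM "         # without parens, the caption is lost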
@@ -2967,7 +2980,7 @@ def trim_and_resize_if_required(
 # for new_cache_latents
 def load_images_and_masks_for_caching(
     image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
-) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
+) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
     r"""
     requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
 
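The corrected return annotation reflects that the per-image alpha masks come back as torch tensors rather than NumPy arrays. A small illustration of the NumPy-to-torch handoff implied by the annotation; the shape and dtype are illustrative, not taken from this function:

    import numpy as np
    import torch

    mask_np = np.ones((64, 64), dtype=np.float32)  # e.g. an alpha channel scaled to [0, 1]
    mask_t = torch.from_numpy(mask_np)             # zero-copy conversion to torch.Tensor
    assert isinstance(mask_t, torch.Tensor) and mask_t.dtype == torch.float32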