Fix system prompt in datasets

This commit is contained in:
rockerBOO
2025-02-23 13:48:37 -05:00
parent 6d7bec8a37
commit 42a801514c
2 changed files with 5 additions and 2 deletions

View File

@@ -280,7 +280,7 @@ def sample_image_inference(
generator=generator,
)
scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0, use_karras_sigmas=True)
scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
# if controlnet_image is not None:

View File

@@ -1869,6 +1869,7 @@ class DreamBoothDataset(BaseDataset):
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
system_prompt: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
@@ -1881,6 +1882,7 @@ class DreamBoothDataset(BaseDataset):
self.is_training_dataset = is_training_dataset
self.validation_seed = validation_seed
self.validation_split = validation_split
self.system_prompt = system_prompt
self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -2098,8 +2100,9 @@ class DreamBoothDataset(BaseDataset):
else:
num_train_images += num_repeats * len(img_paths)
system_prompt = self.system_prompt or subset.system_prompt or ""
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, subset.system_prompt or "" + caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, system_prompt + caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg: