Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-09-26 20:52:08 +09:00
2 changed files with 34 additions and 12 deletions

View File

@@ -710,6 +710,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
- Improvements in OFT (Orthogonal Finetuning) Implementation

View File

@@ -1054,9 +1054,26 @@ class BaseDataset(torch.utils.data.Dataset):
        # sort by resolution
        image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

        # split by resolution and some conditions
        class Condition:
            def __init__(self, reso, flip_aug, alpha_mask, random_crop):
                self.reso = reso
                self.flip_aug = flip_aug
                self.alpha_mask = alpha_mask
                self.random_crop = random_crop

            def __eq__(self, other):
                return (
                    self.reso == other.reso
                    and self.flip_aug == other.flip_aug
                    and self.alpha_mask == other.alpha_mask
                    and self.random_crop == other.random_crop
                )

        batches: List[Tuple[Condition, List[ImageInfo]]] = []
        batch: List[ImageInfo] = []
        current_condition = None

        logger.info("checking cache validity...")
        for info in tqdm(image_infos):
            subset = self.image_to_subset[info.image_key]
@@ -1077,28 +1094,31 @@ class BaseDataset(torch.utils.data.Dataset):
                if cache_available:  # do not add to batch
                    continue

                # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
                condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
                if len(batch) > 0 and current_condition != condition:
                    batches.append((current_condition, batch))
                    batch = []

                batch.append(info)
                current_condition = condition

                # if number of data in batch is enough, flush the batch
                if len(batch) >= vae_batch_size:
                    batches.append((current_condition, batch))
                    batch = []
                    current_condition = None

        if len(batch) > 0:
            batches.append((current_condition, batch))

        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
            return

        # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
        logger.info("caching latents...")
        for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
            cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)

    def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
        r"""
@@ -2516,7 +2536,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
            if "alpha_masks" in example and example["alpha_masks"] is not None:
                alpha_mask = example["alpha_masks"][j]
                logger.info(f"alpha mask size: {alpha_mask.size()}")
                alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)

                if os.name == "nt":
                    cv2.imshow("alpha_mask", alpha_mask)
@@ -5535,7 +5555,7 @@ def save_sd_model_on_train_end_common(
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")

    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
        if args.huber_schedule == "exponential":