mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix flip_aug, alpha_mask, random_crop issue in caching
This commit is contained in:
@@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
- transformers, accelerate and huggingface_hub are updated.
|
||||
- If you encounter any issues, please report them.
|
||||
|
||||
- Fixed a bug in the latents cache: when `flip_aug`, `alpha_mask`, and `random_crop` differ across multiple subsets in the dataset configuration file (.toml), the settings of the last subset were incorrectly applied to all subsets instead of each subset's own settings being used.
|
||||
|
||||
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
||||
|
||||
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
||||
|
||||
@@ -998,9 +998,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
# split by resolution and some conditions
|
||||
class Condition:
|
||||
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
|
||||
self.reso = reso
|
||||
self.flip_aug = flip_aug
|
||||
self.alpha_mask = alpha_mask
|
||||
self.random_crop = random_crop
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||
batch: List[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
@@ -1021,28 +1038,31 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
batches.append((current_condition, batch))
|
||||
|
||||
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
||||
return
|
||||
|
||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||
|
||||
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
|
||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||
@@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
||||
alpha_mask = example["alpha_masks"][j]
|
||||
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
||||
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
|
||||
alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)
|
||||
if os.name == "nt":
|
||||
cv2.imshow("alpha_mask", alpha_mask)
|
||||
|
||||
@@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common(
|
||||
|
||||
|
||||
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu')
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
|
||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||
if args.huber_schedule == "exponential":
|
||||
|
||||
Reference in New Issue
Block a user