mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix flip_aug, alpha_mask, random_crop issue in caching
This commit is contained in:
@@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
|||||||
- transformers, accelerate and huggingface_hub are updated.
|
- transformers, accelerate and huggingface_hub are updated.
|
||||||
- If you encounter any issues, please report them.
|
- If you encounter any issues, please report them.
|
||||||
|
|
||||||
|
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
|
||||||
|
|
||||||
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
||||||
|
|
||||||
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
||||||
|
|||||||
@@ -998,9 +998,26 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
# sort by resolution
|
# sort by resolution
|
||||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||||
|
|
||||||
# split by resolution
|
# split by resolution and some conditions
|
||||||
batches = []
|
class Condition:
|
||||||
batch = []
|
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
|
||||||
|
self.reso = reso
|
||||||
|
self.flip_aug = flip_aug
|
||||||
|
self.alpha_mask = alpha_mask
|
||||||
|
self.random_crop = random_crop
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (
|
||||||
|
self.reso == other.reso
|
||||||
|
and self.flip_aug == other.flip_aug
|
||||||
|
and self.alpha_mask == other.alpha_mask
|
||||||
|
and self.random_crop == other.random_crop
|
||||||
|
)
|
||||||
|
|
||||||
|
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||||
|
batch: List[ImageInfo] = []
|
||||||
|
current_condition = None
|
||||||
|
|
||||||
logger.info("checking cache validity...")
|
logger.info("checking cache validity...")
|
||||||
for info in tqdm(image_infos):
|
for info in tqdm(image_infos):
|
||||||
subset = self.image_to_subset[info.image_key]
|
subset = self.image_to_subset[info.image_key]
|
||||||
@@ -1021,28 +1038,31 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if cache_available: # do not add to batch
|
if cache_available: # do not add to batch
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# if last member of batch has different resolution, flush the batch
|
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||||
batches.append(batch)
|
if len(batch) > 0 and current_condition != condition:
|
||||||
|
batches.append((current_condition, batch))
|
||||||
batch = []
|
batch = []
|
||||||
|
|
||||||
batch.append(info)
|
batch.append(info)
|
||||||
|
current_condition = condition
|
||||||
|
|
||||||
# if number of data in batch is enough, flush the batch
|
# if number of data in batch is enough, flush the batch
|
||||||
if len(batch) >= vae_batch_size:
|
if len(batch) >= vae_batch_size:
|
||||||
batches.append(batch)
|
batches.append((current_condition, batch))
|
||||||
batch = []
|
batch = []
|
||||||
|
current_condition = None
|
||||||
|
|
||||||
if len(batch) > 0:
|
if len(batch) > 0:
|
||||||
batches.append(batch)
|
batches.append((current_condition, batch))
|
||||||
|
|
||||||
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
||||||
return
|
return
|
||||||
|
|
||||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||||
logger.info("caching latents...")
|
logger.info("caching latents...")
|
||||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||||
|
|
||||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||||
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
|
||||||
@@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
if "alpha_masks" in example and example["alpha_masks"] is not None:
|
||||||
alpha_mask = example["alpha_masks"][j]
|
alpha_mask = example["alpha_masks"][j]
|
||||||
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
logger.info(f"alpha mask size: {alpha_mask.size()}")
|
||||||
alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8)
|
alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8)
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
cv2.imshow("alpha_mask", alpha_mask)
|
cv2.imshow("alpha_mask", alpha_mask)
|
||||||
|
|
||||||
@@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common(
|
|||||||
|
|
||||||
|
|
||||||
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu')
|
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||||
|
|
||||||
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
|
||||||
if args.huber_schedule == "exponential":
|
if args.huber_schedule == "exponential":
|
||||||
|
|||||||
Reference in New Issue
Block a user