mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'kohya-ss:main' into min-SNR
This commit is contained in:
@@ -73,8 +73,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
|
||||
|
||||
# region dataset
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
||||
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
||||
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
@@ -675,10 +674,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def cache_latents(self, vae):
|
||||
# TODO ここを高速化したい
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
# ちょっと速くした
|
||||
print("caching latents.")
|
||||
for info in tqdm(self.image_data.values()):
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
for info in image_infos:
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
@@ -689,18 +697,42 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
continue
|
||||
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
# iterate batches
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
images = []
|
||||
for info in batch:
|
||||
image = self.load_image(info.absolute_path)
|
||||
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
|
||||
image = self.image_transforms(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents = latent
|
||||
|
||||
if subset.flip_aug:
|
||||
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
||||
img_tensor = self.image_transforms(image)
|
||||
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
||||
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
info.latents_flipped = latent
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
image = Image.open(image_path)
|
||||
@@ -1197,6 +1229,10 @@ class FineTuningDataset(BaseDataset):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
@@ -1237,10 +1273,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
|
||||
def cache_latents(self, vae):
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae)
|
||||
dataset.cache_latents(vae, vae_batch_size)
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
@@ -1986,6 +2022,7 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
|
||||
)
|
||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user