mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
cache latents to disk in dreambooth method
This commit is contained in:
@@ -142,12 +142,14 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
training_models = []
|
training_models = []
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
|
|||||||
@@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
def is_latent_cacheable(self):
|
def is_latent_cacheable(self):
|
||||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||||
# ちょっと速くした
|
# ちょっと速くした
|
||||||
print("caching latents.")
|
print("caching latents.")
|
||||||
|
|
||||||
@@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if info.latents_npz is not None:
|
if info.latents_npz is not None:
|
||||||
info.latents = self.load_latents_from_npz(info, False)
|
info.latents = self.load_latents_from_npz(info, False)
|
||||||
info.latents = torch.FloatTensor(info.latents)
|
info.latents = torch.FloatTensor(info.latents)
|
||||||
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
|
||||||
|
# might be None, but that's ok because check is done in dataset
|
||||||
|
info.latents_flipped = self.load_latents_from_npz(info, True)
|
||||||
if info.latents_flipped is not None:
|
if info.latents_flipped is not None:
|
||||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# check disk cache exists and size of latents
|
||||||
|
if cache_to_disk:
|
||||||
|
# TODO: refactor to unify with FineTuningDataset
|
||||||
|
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
|
||||||
|
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
|
||||||
|
if not is_main_process:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cache_available = False
|
||||||
|
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
|
||||||
|
if os.path.exists(info.latents_npz):
|
||||||
|
cached_latents = np.load(info.latents_npz)
|
||||||
|
if cached_latents["latents"].shape[1:3] == expected_latents_size:
|
||||||
|
cache_available = True
|
||||||
|
|
||||||
|
if subset.flip_aug:
|
||||||
|
cache_available = False
|
||||||
|
if os.path.exists(info.latents_npz_flipped):
|
||||||
|
cached_latents_flipped = np.load(info.latents_npz_flipped)
|
||||||
|
if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
|
||||||
|
cache_available = True
|
||||||
|
|
||||||
|
if cache_available:
|
||||||
|
continue
|
||||||
|
|
||||||
# if last member of batch has different resolution, flush the batch
|
# if last member of batch has different resolution, flush the batch
|
||||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||||
batches.append(batch)
|
batches.append(batch)
|
||||||
@@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if len(batch) > 0:
|
if len(batch) > 0:
|
||||||
batches.append(batch)
|
batches.append(batch)
|
||||||
|
|
||||||
|
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
|
||||||
|
return
|
||||||
|
|
||||||
# iterate batches
|
# iterate batches
|
||||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||||
images = []
|
images = []
|
||||||
@@ -773,14 +803,21 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||||
|
|
||||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||||
|
|
||||||
for info, latent in zip(batch, latents):
|
for info, latent in zip(batch, latents):
|
||||||
info.latents = latent
|
if cache_to_disk:
|
||||||
|
np.savez(info.latents_npz, latent.float().numpy())
|
||||||
|
else:
|
||||||
|
info.latents = latent
|
||||||
|
|
||||||
if subset.flip_aug:
|
if subset.flip_aug:
|
||||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||||
for info, latent in zip(batch, latents):
|
for info, latent in zip(batch, latents):
|
||||||
info.latents_flipped = latent
|
if cache_to_disk:
|
||||||
|
np.savez(info.latents_npz_flipped, latent.float().numpy())
|
||||||
|
else:
|
||||||
|
info.latents_flipped = latent
|
||||||
|
|
||||||
def get_image_size(self, image_path):
|
def get_image_size(self, image_path):
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
@@ -873,10 +910,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||||
|
|
||||||
# image/latentsを処理する
|
# image/latentsを処理する
|
||||||
if image_info.latents is not None:
|
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||||
image = None
|
image = None
|
||||||
elif image_info.latents_npz is not None:
|
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||||
latents = torch.FloatTensor(latents)
|
latents = torch.FloatTensor(latents)
|
||||||
image = None
|
image = None
|
||||||
@@ -1340,10 +1377,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
|||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
dataset.enable_XTI(*args, **kwargs)
|
dataset.enable_XTI(*args, **kwargs)
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1):
|
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||||
for i, dataset in enumerate(self.datasets):
|
for i, dataset in enumerate(self.datasets):
|
||||||
print(f"[Dataset {i}]")
|
print(f"[Dataset {i}]")
|
||||||
dataset.cache_latents(vae, vae_batch_size)
|
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||||
|
|
||||||
def is_latent_cacheable(self) -> bool:
|
def is_latent_cacheable(self) -> bool:
|
||||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||||
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_latents",
|
"--cache_latents",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
|
help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
|
||||||
)
|
)
|
||||||
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
|
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
|
||||||
|
parser.add_argument(
|
||||||
|
"--cache_latents_to_disk",
|
||||||
|
action="store_true",
|
||||||
|
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
|
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -117,12 +117,14 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||||
unet.requires_grad_(True) # 念のため追加
|
unet.requires_grad_(True) # 念のため追加
|
||||||
|
|||||||
@@ -172,12 +172,14 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# prepare network
|
# prepare network
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -233,12 +233,14 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|||||||
@@ -267,12 +267,14 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|||||||
Reference in New Issue
Block a user