mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
free pipe and cache after sample gen #260
This commit is contained in:
@@ -271,6 +271,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
return NotImplemented
|
||||
return self.image_dir == other.image_dir
|
||||
|
||||
|
||||
class FineTuningSubset(BaseSubset):
|
||||
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
@@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset):
|
||||
return NotImplemented
|
||||
return self.metadata_file == other.metadata_file
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
||||
super().__init__()
|
||||
@@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset):
|
||||
reg_infos: List[ImageInfo] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
print(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
print(
|
||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
continue
|
||||
|
||||
img_paths, captions = load_dreambooth_dir(subset)
|
||||
@@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
print(
|
||||
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
print(
|
||||
f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||
continue
|
||||
|
||||
# メタデータを読み込む
|
||||
@@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset):
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
@@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
# ここでCUDAのキャッシュクリアとかしたほうがいいのか……
|
||||
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(device)
|
||||
|
||||
@@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
Reference in New Issue
Block a user