From 8d5ba2936371cc9751b2fe1a0cbe07cb87aa9ffd Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 7 Mar 2023 08:06:36 +0900
Subject: [PATCH] free pipe and cache after sample gen #260

---
 library/train_util.py | 70 ++++++++++++++++++++++++-------------------
 1 file changed, 39 insertions(+), 31 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 75176e13..e15ce133 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -7,13 +7,13 @@ import re
 import shutil
 import time
 from typing import (
-  Dict,
-  List,
-  NamedTuple,
-  Optional,
-  Sequence,
-  Tuple,
-  Union,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
 )
 from accelerate import Accelerator
 import glob
@@ -214,24 +214,24 @@ class AugHelper:
   def __init__(self):
     # prepare all possible augmentators
     color_aug_method = albu.OneOf([
-      albu.HueSaturationValue(8, 0, 0, p=.5),
-      albu.RandomGamma((95, 105), p=.5),
+        albu.HueSaturationValue(8, 0, 0, p=.5),
+        albu.RandomGamma((95, 105), p=.5),
     ], p=.33)
     flip_aug_method = albu.HorizontalFlip(p=0.5)
     # key: (use_color_aug, use_flip_aug)
     self.augmentors = {
-      (True, True): albu.Compose([
-        color_aug_method,
-        flip_aug_method,
-      ], p=1.),
-      (True, False): albu.Compose([
-        color_aug_method,
-      ], p=1.),
-      (False, True): albu.Compose([
-        flip_aug_method,
-      ], p=1.),
-      (False, False): None
+        (True, True): albu.Compose([
+            color_aug_method,
+            flip_aug_method,
+        ], p=1.),
+        (True, False): albu.Compose([
+            color_aug_method,
+        ], p=1.),
+        (False, True): albu.Compose([
+            flip_aug_method,
+        ], p=1.),
+        (False, False): None
     }
 
   def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
@@ -260,7 +260,7 @@ class DreamBoothSubset(BaseSubset):
     assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
     super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
-          face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
+                     face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
 
     self.is_reg = is_reg
     self.class_tokens = class_tokens
@@ -271,12 +271,13 @@ class DreamBoothSubset(BaseSubset):
       return NotImplemented
     return self.image_dir == other.image_dir
 
+
 class FineTuningSubset(BaseSubset):
   def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
     assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
     super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
-          face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
+                     face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
 
     self.metadata_file = metadata_file
 
@@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset):
      return NotImplemented
    return self.metadata_file == other.metadata_file
 
+
 class BaseDataset(torch.utils.data.Dataset):
   def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
     super().__init__()
@@ -804,7 +806,7 @@ class DreamBoothDataset(BaseDataset):
           captions.append("")
         else:
           captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
-      
+
       self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
       return img_paths, captions
@@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset):
     reg_infos: List[ImageInfo] = []
     for subset in subsets:
       if subset.num_repeats < 1:
-        print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
+        print(
+            f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
         continue
 
       if subset in self.subsets:
-        print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
+        print(
+            f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
         continue
 
       img_paths, captions = load_dreambooth_dir(subset)
@@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset):
 
     for subset in subsets:
       if subset.num_repeats < 1:
-        print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
+        print(
+            f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
         continue
 
      if subset in self.subsets:
-        print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
+        print(
+            f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
         continue
 
       # メタデータを読み込む
@@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset):
       self.subsets.append(subset)
 
     # check existence of all npz files
-    use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
+    use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
     if use_npz_latents:
       flip_aug_in_subset = False
       npz_any = False
@@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
     print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
     return
 
-  # ここでCUDAのキャッシュクリアとかしたほうがいいのか……
-
   org_vae_device = vae.device  # CPUにいるはず
   vae.to(device)
 
@@ -2346,7 +2350,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
           prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
           if negative_prompt is not None:
             negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-        
+
         image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
 
         ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
@@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
 
         image.save(os.path.join(save_dir, img_filename))
 
+  # clear pipeline and cache to reduce vram usage
+  del pipeline
+  torch.cuda.empty_cache()
+
   torch.set_rng_state(rng_state)
   torch.cuda.set_rng_state(cuda_rng_state)
   vae.to(org_vae_device)
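
For reference, the functional core of this patch is the cleanup added at the end of sample_images(): the sampling pipeline is deleted and the CUDA cache is emptied so the VRAM used for sample generation is released before training continues (#260); the remaining hunks are formatting-only. Below is a minimal sketch of that pattern, not part of the patch itself; make_pipeline and the pipeline(prompt) call are illustrative stand-ins for the pipeline construction and invocation done inside sample_images().

import torch

def sample_then_free(make_pipeline, prompts, device="cuda"):
    # Build the text-to-image pipeline only for sampling; make_pipeline is a
    # hypothetical factory standing in for the pipeline setup in sample_images().
    pipeline = make_pipeline()
    pipeline.to(device)
    images = [pipeline(p).images[0] for p in prompts]

    # clear pipeline and cache to reduce vram usage: dropping the only reference
    # lets Python free the pipeline's weights, and empty_cache() returns the
    # freed blocks held by PyTorch's caching allocator to the GPU.
    del pipeline
    torch.cuda.empty_cache()
    return images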