mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
free pipe and cache after sample gen #260
This commit is contained in:
@@ -7,13 +7,13 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
import glob
|
import glob
|
||||||
@@ -214,24 +214,24 @@ class AugHelper:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
# prepare all possible augmentators
|
# prepare all possible augmentators
|
||||||
color_aug_method = albu.OneOf([
|
color_aug_method = albu.OneOf([
|
||||||
albu.HueSaturationValue(8, 0, 0, p=.5),
|
albu.HueSaturationValue(8, 0, 0, p=.5),
|
||||||
albu.RandomGamma((95, 105), p=.5),
|
albu.RandomGamma((95, 105), p=.5),
|
||||||
], p=.33)
|
], p=.33)
|
||||||
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
flip_aug_method = albu.HorizontalFlip(p=0.5)
|
||||||
|
|
||||||
# key: (use_color_aug, use_flip_aug)
|
# key: (use_color_aug, use_flip_aug)
|
||||||
self.augmentors = {
|
self.augmentors = {
|
||||||
(True, True): albu.Compose([
|
(True, True): albu.Compose([
|
||||||
color_aug_method,
|
color_aug_method,
|
||||||
flip_aug_method,
|
flip_aug_method,
|
||||||
], p=1.),
|
], p=1.),
|
||||||
(True, False): albu.Compose([
|
(True, False): albu.Compose([
|
||||||
color_aug_method,
|
color_aug_method,
|
||||||
], p=1.),
|
], p=1.),
|
||||||
(False, True): albu.Compose([
|
(False, True): albu.Compose([
|
||||||
flip_aug_method,
|
flip_aug_method,
|
||||||
], p=1.),
|
], p=1.),
|
||||||
(False, False): None
|
(False, False): None
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
def get_augmentor(self, use_color_aug: bool, use_flip_aug: bool) -> Optional[albu.Compose]:
|
||||||
@@ -260,7 +260,7 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||||
|
|
||||||
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
||||||
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
||||||
|
|
||||||
self.is_reg = is_reg
|
self.is_reg = is_reg
|
||||||
self.class_tokens = class_tokens
|
self.class_tokens = class_tokens
|
||||||
@@ -271,12 +271,13 @@ class DreamBoothSubset(BaseSubset):
|
|||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.image_dir == other.image_dir
|
return self.image_dir == other.image_dir
|
||||||
|
|
||||||
|
|
||||||
class FineTuningSubset(BaseSubset):
|
class FineTuningSubset(BaseSubset):
|
||||||
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
def __init__(self, image_dir, metadata_file: str, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug, face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate) -> None:
|
||||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||||
|
|
||||||
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
super().__init__(image_dir, num_repeats, shuffle_caption, keep_tokens, color_aug, flip_aug,
|
||||||
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
face_crop_aug_range, random_crop, caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate)
|
||||||
|
|
||||||
self.metadata_file = metadata_file
|
self.metadata_file = metadata_file
|
||||||
|
|
||||||
@@ -285,6 +286,7 @@ class FineTuningSubset(BaseSubset):
|
|||||||
return NotImplemented
|
return NotImplemented
|
||||||
return self.metadata_file == other.metadata_file
|
return self.metadata_file == other.metadata_file
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
def __init__(self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -804,7 +806,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
captions.append("")
|
captions.append("")
|
||||||
else:
|
else:
|
||||||
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
||||||
|
|
||||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||||
|
|
||||||
return img_paths, captions
|
return img_paths, captions
|
||||||
@@ -815,11 +817,13 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
reg_infos: List[ImageInfo] = []
|
reg_infos: List[ImageInfo] = []
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
if subset.num_repeats < 1:
|
if subset.num_repeats < 1:
|
||||||
print(f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
print(
|
||||||
|
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if subset in self.subsets:
|
if subset in self.subsets:
|
||||||
print(f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
print(
|
||||||
|
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
img_paths, captions = load_dreambooth_dir(subset)
|
img_paths, captions = load_dreambooth_dir(subset)
|
||||||
@@ -881,11 +885,13 @@ class FineTuningDataset(BaseDataset):
|
|||||||
|
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
if subset.num_repeats < 1:
|
if subset.num_repeats < 1:
|
||||||
print(f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
print(
|
||||||
|
f"ignore subset with metadata_file='{subset.metadata_file}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if subset in self.subsets:
|
if subset in self.subsets:
|
||||||
print(f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
print(
|
||||||
|
f"ignore duplicated subset with metadata_file='{subset.metadata_file}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# メタデータを読み込む
|
# メタデータを読み込む
|
||||||
@@ -937,7 +943,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.subsets.append(subset)
|
self.subsets.append(subset)
|
||||||
|
|
||||||
# check existence of all npz files
|
# check existence of all npz files
|
||||||
use_npz_latents = all([not(subset.color_aug or subset.random_crop) for subset in self.subsets])
|
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||||
if use_npz_latents:
|
if use_npz_latents:
|
||||||
flip_aug_in_subset = False
|
flip_aug_in_subset = False
|
||||||
npz_any = False
|
npz_any = False
|
||||||
@@ -2209,8 +2215,6 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
|||||||
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# ここでCUDAのキャッシュクリアとかしたほうがいいのか……
|
|
||||||
|
|
||||||
org_vae_device = vae.device # CPUにいるはず
|
org_vae_device = vae.device # CPUにいるはず
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
@@ -2346,7 +2350,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
|||||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
|
||||||
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
|
||||||
|
|
||||||
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
|
||||||
@@ -2356,6 +2360,10 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
|||||||
|
|
||||||
image.save(os.path.join(save_dir, img_filename))
|
image.save(os.path.join(save_dir, img_filename))
|
||||||
|
|
||||||
|
# clear pipeline and cache to reduce vram usage
|
||||||
|
del pipeline
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
torch.cuda.set_rng_state(cuda_rng_state)
|
torch.cuda.set_rng_state(cuda_rng_state)
|
||||||
vae.to(org_vae_device)
|
vae.to(org_vae_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user