mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove dependency for albumenations
This commit is contained in:
@@ -55,7 +55,6 @@ from diffusers import (
|
|||||||
from library import custom_train_functions
|
from library import custom_train_functions
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import albumentations as albu
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import cv2
|
import cv2
|
||||||
@@ -285,42 +284,40 @@ class BucketBatchIndex(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class AugHelper:
|
class AugHelper:
|
||||||
|
# albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# prepare all possible augmentators
|
pass
|
||||||
self.color_aug_method = albu.OneOf(
|
|
||||||
[
|
|
||||||
albu.HueSaturationValue(8, 0, 0, p=0.5),
|
|
||||||
albu.RandomGamma((95, 105), p=0.5),
|
|
||||||
],
|
|
||||||
p=0.33,
|
|
||||||
)
|
|
||||||
|
|
||||||
# key: (use_color_aug, use_flip_aug)
|
def color_aug(self, image: np.ndarray):
|
||||||
# self.augmentors = {
|
# self.color_aug_method = albu.OneOf(
|
||||||
# (True, True): albu.Compose(
|
# [
|
||||||
# [
|
# albu.HueSaturationValue(8, 0, 0, p=0.5),
|
||||||
# color_aug_method,
|
# albu.RandomGamma((95, 105), p=0.5),
|
||||||
# flip_aug_method,
|
# ],
|
||||||
# ],
|
# p=0.33,
|
||||||
# p=1.0,
|
# )
|
||||||
# ),
|
hue_shift_limit = 8
|
||||||
# (True, False): albu.Compose(
|
|
||||||
# [
|
|
||||||
# color_aug_method,
|
|
||||||
# ],
|
|
||||||
# p=1.0,
|
|
||||||
# ),
|
|
||||||
# (False, True): albu.Compose(
|
|
||||||
# [
|
|
||||||
# flip_aug_method,
|
|
||||||
# ],
|
|
||||||
# p=1.0,
|
|
||||||
# ),
|
|
||||||
# (False, False): None,
|
|
||||||
# }
|
|
||||||
|
|
||||||
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]:
|
# remove dependency to albumentations
|
||||||
return self.color_aug_method if use_color_aug else None
|
if random.random() <= 0.33:
|
||||||
|
if random.random() > 0.5:
|
||||||
|
# hue shift
|
||||||
|
hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
|
||||||
|
if hue_shift < 0:
|
||||||
|
hue_shift = 180 + hue_shift
|
||||||
|
hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
|
||||||
|
image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
|
||||||
|
else:
|
||||||
|
# random gamma
|
||||||
|
gamma = random.uniform(0.95, 1.05)
|
||||||
|
image = np.clip(image**gamma, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
return {"image": image}
|
||||||
|
|
||||||
|
def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
|
||||||
|
return self.color_aug if use_color_aug else None
|
||||||
|
|
||||||
|
|
||||||
class BaseSubset:
|
class BaseSubset:
|
||||||
@@ -3443,7 +3440,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
|
|
||||||
name = SchedulerType(name)
|
name = SchedulerType(name)
|
||||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
|
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT:
|
||||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ accelerate==0.19.0
|
|||||||
transformers==4.30.2
|
transformers==4.30.2
|
||||||
diffusers[torch]==0.18.2
|
diffusers[torch]==0.18.2
|
||||||
ftfy==6.1.1
|
ftfy==6.1.1
|
||||||
albumentations==1.3.0
|
# albumentations==1.3.0
|
||||||
opencv-python==4.7.0.68
|
opencv-python==4.7.0.68
|
||||||
einops==0.6.0
|
einops==0.6.0
|
||||||
pytorch-lightning==1.9.0
|
pytorch-lightning==1.9.0
|
||||||
|
|||||||
Reference in New Issue
Block a user