remove dependency for albumenations

This commit is contained in:
Kohya S
2023-07-30 16:29:53 +09:00
parent 496c3f2732
commit f61996b425
2 changed files with 33 additions and 36 deletions

View File

@@ -55,7 +55,6 @@ from diffusers import (
from library import custom_train_functions from library import custom_train_functions
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import cv2 import cv2
@@ -285,42 +284,40 @@ class BucketBatchIndex(NamedTuple):
class AugHelper: class AugHelper:
# albumentationsへの依存をなくしたがとりあえず同じinterfaceを持たせる
def __init__(self): def __init__(self):
# prepare all possible augmentators pass
self.color_aug_method = albu.OneOf(
[
albu.HueSaturationValue(8, 0, 0, p=0.5),
albu.RandomGamma((95, 105), p=0.5),
],
p=0.33,
)
# key: (use_color_aug, use_flip_aug) def color_aug(self, image: np.ndarray):
# self.augmentors = { # self.color_aug_method = albu.OneOf(
# (True, True): albu.Compose(
# [ # [
# color_aug_method, # albu.HueSaturationValue(8, 0, 0, p=0.5),
# flip_aug_method, # albu.RandomGamma((95, 105), p=0.5),
# ], # ],
# p=1.0, # p=0.33,
# ), # )
# (True, False): albu.Compose( hue_shift_limit = 8
# [
# color_aug_method,
# ],
# p=1.0,
# ),
# (False, True): albu.Compose(
# [
# flip_aug_method,
# ],
# p=1.0,
# ),
# (False, False): None,
# }
def get_augmentor(self, use_color_aug: bool) -> Optional[albu.Compose]: # remove dependency to albumentations
return self.color_aug_method if use_color_aug else None if random.random() <= 0.33:
if random.random() > 0.5:
# hue shift
hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
if hue_shift < 0:
hue_shift = 180 + hue_shift
hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
else:
# random gamma
gamma = random.uniform(0.95, 1.05)
image = np.clip(image**gamma, 0, 255).astype(np.uint8)
return {"image": image}
def get_augmentor(self, use_color_aug: bool): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
return self.color_aug if use_color_aug else None
class BaseSubset: class BaseSubset:

View File

@@ -2,7 +2,7 @@ accelerate==0.19.0
transformers==4.30.2 transformers==4.30.2
diffusers[torch]==0.18.2 diffusers[torch]==0.18.2
ftfy==6.1.1 ftfy==6.1.1
albumentations==1.3.0 # albumentations==1.3.0
opencv-python==4.7.0.68 opencv-python==4.7.0.68
einops==0.6.0 einops==0.6.0
pytorch-lightning==1.9.0 pytorch-lightning==1.9.0