mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add target_x flag (not sure this impl is correct)
This commit is contained in:
@@ -28,6 +28,7 @@ class GradualLatent:
|
||||
gaussian_blur_ksize=None,
|
||||
gaussian_blur_sigma=0.5,
|
||||
gaussian_blur_strength=0.5,
|
||||
unsharp_target_x=True,
|
||||
):
|
||||
self.ratio = ratio
|
||||
self.start_timesteps = start_timesteps
|
||||
@@ -37,12 +38,14 @@ class GradualLatent:
|
||||
self.gaussian_blur_ksize = gaussian_blur_ksize
|
||||
self.gaussian_blur_sigma = gaussian_blur_sigma
|
||||
self.gaussian_blur_strength = gaussian_blur_strength
|
||||
self.unsharp_target_x = unsharp_target_x
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
||||
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
||||
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength})"
|
||||
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
||||
+ f"unsharp_target_x={self.unsharp_target_x})"
|
||||
)
|
||||
|
||||
def apply_unshark_mask(self, x: torch.Tensor):
|
||||
@@ -54,6 +57,19 @@ class GradualLatent:
|
||||
sharpened = x + mask
|
||||
return sharpened
|
||||
|
||||
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
||||
org_dtype = x.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.float()
|
||||
|
||||
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if unsharp and self.gaussian_blur_ksize:
|
||||
x = self.apply_unshark_mask(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -140,29 +156,25 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
|
||||
dt = sigma_down - sigma
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
device = model_output.device
|
||||
if self.resized_size is None:
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
||||
)
|
||||
s_noise = 1.0
|
||||
else:
|
||||
print(
|
||||
"resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape
|
||||
)
|
||||
org_dtype = prev_sample.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
prev_sample = prev_sample.float()
|
||||
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
||||
s_noise = self.gradual_latent.s_noise
|
||||
|
||||
prev_sample = torch.nn.functional.interpolate(
|
||||
prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False
|
||||
).to(dtype=org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if self.gradual_latent.gaussian_blur_ksize:
|
||||
prev_sample = self.gradual_latent.apply_unshark_mask(prev_sample)
|
||||
if self.gradual_latent.unsharp_target_x:
|
||||
prev_sample = sample + derivative * dt
|
||||
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
||||
else:
|
||||
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
||||
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
||||
@@ -170,7 +182,6 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
s_noise = self.gradual_latent.s_noise
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up * s_noise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user