apply unsharp mask

This commit is contained in:
Kohya S
2023-11-27 23:50:21 +09:00
parent 298c6c2343
commit 2c50ea0403

View File

@@ -802,6 +802,10 @@ class PipelineLike:
latents, scale_factor=current_ratio, mode="bicubic", align_corners=False latents, scale_factor=current_ratio, mode="bicubic", align_corners=False
).to(org_dtype) ).to(org_dtype)
# apply unsharp mask / アンシャープマスクを適用する
blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(latents)
latents = latents + (latents - blurred) * 0.5
for i, t in enumerate(tqdm(timesteps)): for i, t in enumerate(tqdm(timesteps)):
resized_size = None resized_size = None
if enable_gradual_latent: if enable_gradual_latent:
@@ -1434,8 +1438,9 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.resized_size = None self.resized_size = None
def set_resized_size(self, size): def set_resized_size(self, size, s_noise=0.5):
self.resized_size = size self.resized_size = size
self.s_noise = s_noise
def step( def step(
self, self,
@@ -1518,6 +1523,7 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
model_output.shape, dtype=model_output.dtype, device=device, generator=generator model_output.shape, dtype=model_output.dtype, device=device, generator=generator
) )
s_noise = 1.0
else: else:
print( print(
"resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape
@@ -1530,14 +1536,19 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False
).to(dtype=org_dtype) ).to(dtype=org_dtype)
# apply unsharp mask / アンシャープマスクを適用する
blurred = torchvision.transforms.transforms.GaussianBlur(3, sigma=(0.5, 0.5))(prev_sample)
prev_sample = prev_sample + (prev_sample - blurred) * 0.5
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
dtype=model_output.dtype, dtype=model_output.dtype,
device=device, device=device,
generator=generator, generator=generator,
) )
s_noise = self.s_noise
prev_sample = prev_sample + noise * sigma_up prev_sample = prev_sample + noise * sigma_up * s_noise
# upon completion increase step index by one # upon completion increase step index by one
self._step_index += 1 self._step_index += 1