diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 6428ef8a..74f3852d 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2569,10 +2569,12 @@ def main(args): # Gradual Latent if args.gradual_latent_timesteps is not None: if args.gradual_latent_unsharp_params: - ksize, sigma, strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] - ksize = int(ksize) + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( args.gradual_latent_ratio, @@ -2580,9 +2582,10 @@ def main(args): args.gradual_latent_every_n_steps, args.gradual_latent_ratio_step, args.gradual_latent_s_noise, - ksize, - sigma, - strength, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3348,12 +3351,23 @@ def main(args): if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 if gl_unsharp_params is not None: - ksize, sigma, strength = [float(v) for v in gl_unsharp_params.split(",")] - ksize = int(ksize) + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + print(unsharp_params) + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) + us_ksize = int(us_ksize) else: - ksize, sigma, strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( - gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, ksize, sigma, strength + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3765,8 +3779,8 @@ def setup_parser() -> argparse.ArgumentParser: "--gradual_latent_unsharp_params", type=str, default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) return parser diff --git a/library/utils.py b/library/utils.py index 1ccb141c..48fb40e2 100644 --- a/library/utils.py +++ b/library/utils.py @@ -28,6 +28,7 @@ class GradualLatent: gaussian_blur_ksize=None, gaussian_blur_sigma=0.5, gaussian_blur_strength=0.5, + unsharp_target_x=True, ): self.ratio = ratio self.start_timesteps = start_timesteps @@ -37,12 +38,14 @@ class GradualLatent: self.gaussian_blur_ksize = gaussian_blur_ksize self.gaussian_blur_sigma = gaussian_blur_sigma self.gaussian_blur_strength = gaussian_blur_strength + self.unsharp_target_x = unsharp_target_x def __str__(self) -> str: return ( f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " - + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength})" + + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + + f"unsharp_target_x={self.unsharp_target_x})" ) def apply_unshark_mask(self, x: torch.Tensor): @@ -54,6 +57,19 @@ class GradualLatent: sharpened = x + mask return sharpened + def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.float() + + x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) + + # apply unsharp mask / アンシャープマスクを適用する + if unsharp and self.gaussian_blur_ksize: + x = self.apply_unshark_mask(x) + + return x + class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): def __init__(self, *args, **kwargs): @@ -140,29 +156,25 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): dt = sigma_down - sigma - prev_sample = sample + derivative * dt - device = model_output.device if self.resized_size is None: + prev_sample = sample + derivative * dt + noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( model_output.shape, dtype=model_output.dtype, device=device, generator=generator ) s_noise = 1.0 else: - print( - "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape - ) - org_dtype = prev_sample.dtype - if org_dtype == torch.bfloat16: - prev_sample = prev_sample.float() + print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) + s_noise = self.gradual_latent.s_noise - prev_sample = torch.nn.functional.interpolate( - prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False - ).to(dtype=org_dtype) - - # apply unsharp mask / アンシャープマスクを適用する - if self.gradual_latent.gaussian_blur_ksize: - prev_sample = self.gradual_latent.apply_unshark_mask(prev_sample) + if self.gradual_latent.unsharp_target_x: + prev_sample = sample + derivative * dt + prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) + else: + sample = self.gradual_latent.interpolate(sample, self.resized_size) + derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) + prev_sample = sample + derivative * dt noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), @@ -170,7 +182,6 @@ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): device=device, generator=generator, ) - s_noise = self.gradual_latent.s_noise prev_sample = prev_sample + noise * sigma_up * s_noise diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 35d6575e..bfe2e512 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1829,10 +1829,12 @@ def main(args): # Gradual Latent if args.gradual_latent_timesteps is not None: if args.gradual_latent_unsharp_params: - us_ksize, us_sigma, us_strength = [float(v) for v in args.gradual_latent_unsharp_params.split(",")] + us_params = args.gradual_latent_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]] + us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3])) us_ksize = int(us_ksize) else: - us_ksize, us_sigma, us_strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( args.gradual_latent_ratio, @@ -1843,6 +1845,7 @@ def main(args): us_ksize, us_sigma, us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -2650,12 +2653,22 @@ def main(args): if gl_timesteps < 0: gl_timesteps = args.gradual_latent_timesteps or 650 if gl_unsharp_params is not None: - us_ksize, us_sigma, us_strength = [float(v) for v in gl_unsharp_params.split(",")] + unsharp_params = gl_unsharp_params.split(",") + us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]] + us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3])) us_ksize = int(us_ksize) else: - us_ksize, us_sigma, us_strength = None, None, None + us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None gradual_latent = GradualLatent( - gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step, gl_s_noise, us_ksize, us_sigma, us_strength + gl_ratio, + gl_timesteps, + gl_every_n_steps, + gl_ratio_step, + gl_s_noise, + us_ksize, + us_sigma, + us_strength, + us_target_x, ) pipe.set_gradual_latent(gradual_latent) @@ -3056,8 +3069,8 @@ def setup_parser() -> argparse.ArgumentParser: "--gradual_latent_unsharp_params", type=str, default=None, - help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength. `3,0.5,0.5` is recommended /" - + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength. `3,0.5,0.5` が推奨", + help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /" + + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) # # parser.add_argument(