mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix pynoise code bug (#489)
* fix pynoise * Update custom_train_functions.py for default * Update custom_train_functions.py for note * Update custom_train_functions.py for default * Revert "Update custom_train_functions.py for default" This reverts commitca79915d73. * Update custom_train_functions.py for default * Revert "Update custom_train_functions.py for default" This reverts commit483577e137. * default value change
This commit is contained in:
@@ -346,14 +346,14 @@ def get_weighted_text_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
|
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||||
b, c, w, h = noise.shape
|
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
r = random.random() * 2 + 2 # Rather than always going 2x,
|
r = random.random() * 2 + 2 # Rather than always going 2x,
|
||||||
w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
||||||
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
|
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
||||||
if w == 1 or h == 1:
|
if wn == 1 or hn == 1:
|
||||||
break # Lowest resolution is 1x1
|
break # Lowest resolution is 1x1
|
||||||
return noise / noise.std() # Scaled back to roughly unit variance
|
return noise / noise.std() # Scaled back to roughly unit variance
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user