mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Merge 8b0a467bc0 into 2e0fcc50cb
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -528,7 +528,6 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
|
||||
@@ -4660,6 +4660,27 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
|
||||
if section_name == "wavelet_loss_band_level_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_level_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_quaternion_component_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
@@ -509,6 +509,26 @@ def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
|
||||
# Debugging tool for saving latent as image
|
||||
def save_latent_as_img(vae, latent_to: torch.Tensor, output_name: str) -> None:
    """Decode a latent with ``vae`` and save the first decoded image to disk.

    Debugging helper. Only the first item of the decoded batch is written.

    Args:
        vae: Object exposing ``dtype`` and ``decode(latent)``. ``decode`` is
            assumed to return a tensor laid out as
            [batch_size, channels, height, width] with values roughly in
            [-1, 1] -- TODO confirm for the VAE in use.
        latent_to: Latent tensor to decode; it is cast to ``vae.dtype`` first.
        output_name: Destination path handed to ``PIL.Image.Image.save``.
    """
    with torch.no_grad():
        decoded = vae.decode(latent_to.to(vae.dtype)).float()

        # VAE outputs are typically in the range [-1, 1]; map them to [0, 1].
        decoded = (decoded / 2 + 0.5).clamp(0, 1)

        # Scale to [0, 255] and round before the uint8 cast: a bare cast
        # truncates, which biases every pixel down by up to one level.
        pixels = (decoded * 255).round().cpu().numpy().astype(np.uint8)

    # Rearrange [batch_size, channels, height, width] ->
    # [batch_size, height, width, channels], the layout PIL expects.
    pixels = pixels.transpose(0, 2, 3, 1)

    # Take the first image of the batch and write it out.
    Image.fromarray(pixels[0]).save(output_name)
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
Reference in New Issue
Block a user