add latent scaling/shifting

This commit is contained in:
kohya-ss
2024-10-25 23:20:38 +09:00
parent d2c549d7b2
commit 0031d916f0

View File

@@ -6,7 +6,7 @@ from typing import Any, Optional
import torch
from accelerate import Accelerator
from library import strategy_sd3, utils
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -268,7 +267,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
return sd3_models.SDVAE.process_in(latents)
def get_noise_pred_and_target(
self,