mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add latent scaling/shifting
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from library import strategy_sd3, utils
|
||||
from library import sd3_models, strategy_sd3, utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.is_schnell: Optional[bool] = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
@@ -268,7 +267,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
return vae.encode(images)
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
return latents
|
||||
return sd3_models.SDVAE.process_in(latents)
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user