mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add latent scaling/shifting
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from library import strategy_sd3, utils
|
from library import sd3_models, strategy_sd3, utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -25,7 +25,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sample_prompts_te_outputs = None
|
self.sample_prompts_te_outputs = None
|
||||||
self.is_schnell: Optional[bool] = None
|
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group):
|
def assert_extra_args(self, args, train_dataset_group):
|
||||||
super().assert_extra_args(args, train_dataset_group)
|
super().assert_extra_args(args, train_dataset_group)
|
||||||
@@ -268,7 +267,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return vae.encode(images)
|
return vae.encode(images)
|
||||||
|
|
||||||
def shift_scale_latents(self, args, latents):
|
def shift_scale_latents(self, args, latents):
|
||||||
return latents
|
return sd3_models.SDVAE.process_in(latents)
|
||||||
|
|
||||||
def get_noise_pred_and_target(
|
def get_noise_pred_and_target(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user