mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update README and clean up the code for SD3 timesteps
This commit is contained in:
@@ -275,9 +275,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
# shift 3.0 is the default value
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
|
||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
# this scheduler is not used in training, but used to get num_train_timesteps etc.
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
@@ -304,7 +303,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
|
||||
args, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
|
||||
Reference in New Issue
Block a user