Update README and clean-up the code for SD3 timesteps

This commit is contained in:
Kohya S
2024-11-07 21:27:12 +09:00
parent 588ea9e123
commit 5e86323f12
6 changed files with 30 additions and 19 deletions

View File

@@ -275,9 +275,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# shift 3.0 is the default value
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
# this scheduler is not used in training, but used to get num_train_timesteps etc.
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
@@ -304,7 +303,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, self.noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
args, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad