mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Update train_network.py
This commit is contained in:
@@ -178,8 +178,7 @@ class NetworkTrainer:
|
|||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device)
|
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
||||||
timesteps = timesteps.long()
|
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
|
|||||||
Reference in New Issue
Block a user