mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Typos and lingering is_train
This commit is contained in:
@@ -380,11 +380,11 @@ class NetworkTrainer:
|
||||
else:
|
||||
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
|
||||
|
||||
choosen_timesteps_list = pick_timesteps_list()
|
||||
chosen_timesteps_list = pick_timesteps_list()
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timestep in choosen_timesteps_list:
|
||||
for fixed_timestep in chosen_timesteps_list:
|
||||
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
|
||||
|
||||
# Predict the noise residual
|
||||
@@ -447,7 +447,7 @@ class NetworkTrainer:
|
||||
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(choosen_timesteps_list)
|
||||
return total_loss / len(chosen_timesteps_list)
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
|
||||
Reference in New Issue
Block a user