Typos and lingering is_train

This commit is contained in:
rockerBOO
2025-01-03 01:18:15 -05:00
parent 7470173044
commit 534059dea5
3 changed files with 4 additions and 8 deletions

View File

@@ -380,11 +380,11 @@ class NetworkTrainer:
else:
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
choosen_timesteps_list = pick_timesteps_list()
chosen_timesteps_list = pick_timesteps_list()
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
# Use input timesteps_list or use described timesteps above
for fixed_timestep in choosen_timesteps_list:
for fixed_timestep in chosen_timesteps_list:
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
# Predict the noise residual
@@ -447,7 +447,7 @@ class NetworkTrainer:
total_loss += loss
return total_loss / len(choosen_timesteps_list)
return total_loss / len(chosen_timesteps_list)
def train(self, args):
session_id = random.randint(0, 2**32)