diff --git a/train_network.py b/train_network.py index 938e4193..e10c17c0 100644 --- a/train_network.py +++ b/train_network.py @@ -192,7 +192,7 @@ class NetworkTrainer: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし total_loss += loss - + average_loss = total_loss / len(timesteps_list) return average_loss