From 63992b81c840ea42b53d70d611ef27ff85ae397e Mon Sep 17 00:00:00 2001
From: Yuta Hayashibe
Date: Fri, 27 Oct 2023 21:13:29 +0900
Subject: [PATCH] Fix initialization place of loss_recorder

---
 fine_tune.py  | 2 +-
 sdxl_train.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 27d64739..afec7d27 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -288,6 +288,7 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -295,7 +296,6 @@ def train(args):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
diff --git a/sdxl_train.py b/sdxl_train.py
index 9017d7b8..f681f28f 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -452,6 +452,7 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -459,7 +460,6 @@ def train(args):
         for m in training_models:
             m.train()
 
-        loss_recorder = train_util.LossRecorder()
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
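
Note on the change: both hunks move the line loss_recorder = train_util.LossRecorder() from inside the epoch loop to just before it, so the recorder's accumulated per-step losses survive across epochs instead of being thrown away at the start of every epoch. The sketch below illustrates why the placement matters. SimplifiedLossRecorder is a hypothetical stand-in, assumed to keep one running list of losses for a moving average; it is not the actual train_util.LossRecorder implementation.

# Minimal sketch of the bug this patch fixes. SimplifiedLossRecorder is a
# hypothetical stand-in for train_util.LossRecorder, assumed here to
# accumulate per-step losses so a moving average can cover the whole run.

class SimplifiedLossRecorder:
    def __init__(self):
        self.loss_list = []   # state that should persist across epochs
        self.loss_total = 0.0

    def add(self, loss):
        self.loss_list.append(loss)
        self.loss_total += loss

    @property
    def moving_average(self):
        return self.loss_total / len(self.loss_list)


# Before the patch: the recorder was re-created inside the epoch loop,
# so its accumulated state was discarded at the start of every epoch.
def train_buggy(epoch_losses):
    for losses in epoch_losses:
        recorder = SimplifiedLossRecorder()  # reset every epoch (bug)
        for loss in losses:
            recorder.add(loss)
    return recorder.moving_average  # reflects only the last epoch


# After the patch: one recorder outlives all epochs, so the average
# (and any logging built on it) covers the entire training run.
def train_fixed(epoch_losses):
    recorder = SimplifiedLossRecorder()  # created once, before the loop
    for losses in epoch_losses:
        for loss in losses:
            recorder.add(loss)
    return recorder.moving_average


if __name__ == "__main__":
    toy_losses = [[1.0, 1.0], [0.0, 0.0]]  # two epochs of toy loss values
    print(train_buggy(toy_losses))  # 0.0 -- last epoch only
    print(train_fixed(toy_losses))  # 0.5 -- whole run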