mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code
This commit is contained in:
@@ -327,8 +327,8 @@ class NetworkTrainer:
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy: strategy_sd.SdTextEncodingStrategy,
|
||||
tokenize_strategy: strategy_sd.SdTokenizeStrategy,
|
||||
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True
|
||||
@@ -1183,17 +1183,7 @@ class NetworkTrainer:
|
||||
|
||||
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
train_util.init_trackers(accelerator, args, "network_train")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
@@ -1386,15 +1376,14 @@ class NetworkTrainer:
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_epoch = (
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if validation_steps > 0 and should_validate_epoch:
|
||||
if validation_steps > 0 and should_validate_step:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
@@ -1406,6 +1395,9 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
@@ -1428,18 +1420,22 @@ class NetworkTrainer:
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/step_validation_current": current_loss}
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
logs = {
|
||||
"loss/validation/step/current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step/average": val_step_loss_recorder.moving_average,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# VALIDATION EPOCH
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
@@ -1458,6 +1454,9 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
@@ -1480,21 +1479,22 @@ class NetworkTrainer:
|
||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_validation_current": current_loss}
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
logs = {
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_steps) + val_step
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
logs = {"loss/epoch_validation_average": avr_loss}
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user