sync update to train_native

Darren Laurie
2025-04-06 16:34:45 +08:00
parent 83b1ab640a
commit d932129093
2 changed files with 30 additions and 10 deletions

@@ -78,14 +78,14 @@ class SdxlNativeTrainer(train_native.NativeTrainer):
         train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)

     def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
-        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
+        #Disabled for 64 / 32 conflict. Has been checked below.
+        #super().assert_extra_args(args, train_dataset_group, val_dataset_group)
         sdxl_train_util.verify_sdxl_training_args(args)

         if args.cache_text_encoder_outputs:
             assert (
                 train_dataset_group.is_text_encoder_output_cacheable()
             ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

         train_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
         if val_dataset_group is not None:
             val_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
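A note on the disabled super() call above: the comment suggests the parent's assert_extra_args enforces a bucket-resolution step (64) that conflicts with the value this subclass needs (32), so the call is skipped and the check is re-run at the end of the method with self.arb_min_steps. A minimal, hypothetical sketch of that override pattern; the class names, the DatasetGroup stand-in, and the 64/32 values are assumptions for illustration, not the repository's actual code:

```python
# Hypothetical sketch of skipping a parent check and re-running it
# with the subclass's own minimum step value.
class DatasetGroup:
    def __init__(self, reso_step: int):
        self.reso_step = reso_step

    def verify_bucket_reso_steps(self, min_steps: int) -> None:
        # Bucket resolutions must be multiples of the trainer's minimum step.
        assert self.reso_step % min_steps == 0, f"bucket reso must step by {min_steps}"

class BaseTrainer:
    arb_min_steps = 64  # assumed base-model bucket step

    def assert_extra_args(self, args, dataset_group: DatasetGroup) -> None:
        dataset_group.verify_bucket_reso_steps(self.arb_min_steps)

class SdxlTrainer(BaseTrainer):
    arb_min_steps = 32  # assumed SDXL bucket step; conflicts with the parent's 64

    def assert_extra_args(self, args, dataset_group: DatasetGroup) -> None:
        # Skip the parent's 64-step check; re-run it with this class's value.
        # super().assert_extra_args(args, dataset_group)
        dataset_group.verify_bucket_reso_steps(self.arb_min_steps)

# 96 is a multiple of 32, so this passes; the parent's 64-step check would fail.
SdxlTrainer().assert_extra_args(None, DatasetGroup(96))
```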

@@ -63,6 +63,8 @@ class NativeTrainer:
         keys_scaled=None,
         mean_norm=None,
         maximum_norm=None,
+        mean_grad_norm=None,
+        mean_combined_norm=None,
     ):
         # Assumed network_train_unet_only is False
@@ -70,8 +72,13 @@ class NativeTrainer:
         if keys_scaled is not None:
             logs["max_norm/keys_scaled"] = keys_scaled
-            logs["max_norm/average_key_norm"] = mean_norm
             logs["max_norm/max_key_norm"] = maximum_norm
+        if mean_norm is not None:
+            logs["norm/avg_key_norm"] = mean_norm
+        if mean_grad_norm is not None:
+            logs["norm/avg_grad_norm"] = mean_grad_norm
+        if mean_combined_norm is not None:
+            logs["norm/avg_combined_norm"] = mean_combined_norm

         lrs = lr_scheduler.get_last_lr()
         for i, lr in enumerate(lrs):
@@ -1101,7 +1108,8 @@ class NativeTrainer:
             "ss_validation_split": args.validation_split,
             "ss_max_validation_steps": args.max_validation_steps,
             "ss_validate_every_n_epochs": args.validate_every_n_epochs,
-            "ss_validate_every_n_steps": args.validate_every_n_steps,
+            "ss_validate_every_n_steps": args.validate_every_n_steps,
+            "ss_resize_interpolation": args.resize_interpolation,
         }

         self.update_metadata(metadata, args)  # architecture specific metadata
@@ -1126,7 +1134,8 @@ class NativeTrainer:
                 "min_bucket_reso": dataset.min_bucket_reso,
                 "max_bucket_reso": dataset.max_bucket_reso,
                 "tag_frequency": dataset.tag_frequency,
-                "bucket_info": dataset.bucket_info,
+                "bucket_info": dataset.bucket_info,
+                "resize_interpolation": dataset.resize_interpolation,
             }

             subsets_metadata = []
@@ -1144,6 +1153,7 @@ class NativeTrainer:
                 "enable_wildcard": bool(subset.enable_wildcard),
                 "caption_prefix": subset.caption_prefix,
                 "caption_suffix": subset.caption_suffix,
+                "resize_interpolation": subset.resize_interpolation,
             }

             image_dir_or_metadata_file = None
@@ -1428,6 +1438,12 @@ class NativeTrainer:
                    if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
                        params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

+                    # lora_flux exclusive
+                    if hasattr(training_model, "update_grad_norms"):
+                        training_model.update_grad_norms()
+
+                    if hasattr(training_model, "update_norms"):
+                        training_model.update_norms()

                    optimizer.step()
                    lr_scheduler.step()
@@ -1439,9 +1455,14 @@ class NativeTrainer:
                            args.scale_weight_norms, accelerator.device
                        )
                        # TODO: Multiple models
+                        mean_grad_norm = None
+                        mean_combined_norm = None
                        max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
                    else:
                        keys_scaled, mean_norm, maximum_norm = None, None, None
+                        mean_grad_norm = None
+                        mean_combined_norm = None
                        max_mean_logs = {}

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
@@ -1478,10 +1499,7 @@ class NativeTrainer:
                    loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                    avr_loss: float = loss_recorder.moving_average
                    logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-                    progress_bar.set_postfix(**logs)
-
-                    if args.scale_weight_norms:
-                        progress_bar.set_postfix(**{**max_mean_logs, **logs})
+                    progress_bar.set_postfix(**{**max_mean_logs, **logs})

                    if is_tracking:
                        logs = self.generate_step_logs(
@@ -1493,7 +1511,9 @@ class NativeTrainer:
                            optimizer,
                            keys_scaled,
                            mean_norm,
-                            maximum_norm
+                            maximum_norm,
+                            mean_grad_norm,
+                            mean_combined_norm,
                        )
                        accelerator.log(logs, step=global_step)
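The update_grad_norms / update_norms calls above are duck-typed: the trainer only probes with hasattr, so any network exposing those methods gets its norms refreshed after gradient clipping and before optimizer.step(), and the resulting means can then flow into the new norm/avg_grad_norm and norm/avg_combined_norm log keys via generate_step_logs. A minimal sketch of a network satisfying that protocol; the class, its attributes, and the aggregation are assumptions for illustration, not lora_flux's actual implementation:

```python
import torch

# Hypothetical network satisfying the duck-typed hooks the trainer probes for.
class ToyNetwork(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        self.grad_norms: list[float] = []
        self.weight_norms: list[float] = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def update_grad_norms(self) -> None:
        # Called after clipping, before optimizer.step(), so the recorded
        # norms reflect the gradients actually applied in the update.
        self.grad_norms = [p.grad.norm().item() for p in self.parameters() if p.grad is not None]

    def update_norms(self) -> None:
        self.weight_norms = [p.detach().norm().item() for p in self.parameters()]

net = ToyNetwork()
net(torch.randn(2, 4)).sum().backward()
if hasattr(net, "update_grad_norms"):  # mirrors the trainer's guard
    net.update_grad_norms()
mean_grad_norm = sum(net.grad_norms) / len(net.grad_norms)
logs = {"norm/avg_grad_norm": mean_grad_norm}  # one of the new log keys
print(logs)
```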