mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
sync update to train_native
This commit is contained in:
@@ -78,14 +78,14 @@ class SdxlNativeTrainer(train_native.NativeTrainer):
|
||||
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
#Disabled for 64 / 32 conflict. Has been checked below.
|
||||
#super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
|
||||
|
||||
@@ -63,6 +63,8 @@ class NativeTrainer:
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
mean_grad_norm=None,
|
||||
mean_combined_norm=None,
|
||||
):
|
||||
# Assumed network_train_unet_only is False
|
||||
|
||||
@@ -70,8 +72,13 @@ class NativeTrainer:
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/average_key_norm"] = mean_norm
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
if mean_norm is not None:
|
||||
logs["norm/avg_key_norm"] = mean_norm
|
||||
if mean_grad_norm is not None:
|
||||
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||
if mean_combined_norm is not None:
|
||||
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lrs):
|
||||
@@ -1101,7 +1108,8 @@ class NativeTrainer:
|
||||
"ss_validation_split": args.validation_split,
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1126,7 +1134,8 @@ class NativeTrainer:
|
||||
"min_bucket_reso": dataset.min_bucket_reso,
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
}
|
||||
|
||||
subsets_metadata = []
|
||||
@@ -1144,6 +1153,7 @@ class NativeTrainer:
|
||||
"enable_wildcard": bool(subset.enable_wildcard),
|
||||
"caption_prefix": subset.caption_prefix,
|
||||
"caption_suffix": subset.caption_suffix,
|
||||
"resize_interpolation": subset.resize_interpolation,
|
||||
}
|
||||
|
||||
image_dir_or_metadata_file = None
|
||||
@@ -1428,6 +1438,12 @@ class NativeTrainer:
|
||||
if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
|
||||
params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
# lora_flux exclusive
|
||||
if hasattr(training_model, "update_grad_norms"):
|
||||
training_model.update_grad_norms()
|
||||
if hasattr(training_model, "update_norms"):
|
||||
training_model.update_norms()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
@@ -1439,9 +1455,14 @@ class NativeTrainer:
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
# TODO: Multiple models
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1478,10 +1499,7 @@ class NativeTrainer:
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
@@ -1493,7 +1511,9 @@ class NativeTrainer:
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user