From d932129093e8514dafdbb467987d6544ea12382c Mon Sep 17 00:00:00 2001
From: Darren Laurie <6DammK9@gmail.com>
Date: Sun, 6 Apr 2025 16:34:45 +0800
Subject: [PATCH] sync update to train_native

---
 sdxl_train.py   |  4 ++--
 train_native.py | 36 ++++++++++++++++++++++++++++--------
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/sdxl_train.py b/sdxl_train.py
index b1dadfbc..8537a99b 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -78,14 +78,14 @@ class SdxlNativeTrainer(train_native.NativeTrainer):
         train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
 
     def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
-        super().assert_extra_args(args, train_dataset_group, val_dataset_group)
+        # Disabled for 64 / 32 conflict. Has been checked below.
+        # super().assert_extra_args(args, train_dataset_group, val_dataset_group)
         sdxl_train_util.verify_sdxl_training_args(args)
 
         if args.cache_text_encoder_outputs:
             assert (
                 train_dataset_group.is_text_encoder_output_cacheable()
             ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
 
-        train_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
         if val_dataset_group is not None:
             val_dataset_group.verify_bucket_reso_steps(self.arb_min_steps)
diff --git a/train_native.py b/train_native.py
index a5062e01..42e838da 100644
--- a/train_native.py
+++ b/train_native.py
@@ -63,6 +63,8 @@ class NativeTrainer:
         keys_scaled=None,
         mean_norm=None,
         maximum_norm=None,
+        mean_grad_norm=None,
+        mean_combined_norm=None,
     ):
         # Assumed network_train_unet_only is False
 
@@ -70,8 +72,13 @@ class NativeTrainer:
         if keys_scaled is not None:
             logs["max_norm/keys_scaled"] = keys_scaled
-            logs["max_norm/average_key_norm"] = mean_norm
             logs["max_norm/max_key_norm"] = maximum_norm
+        if mean_norm is not None:
+            logs["norm/avg_key_norm"] = mean_norm
+        if mean_grad_norm is not None:
+            logs["norm/avg_grad_norm"] = mean_grad_norm
+        if mean_combined_norm is not None:
+            logs["norm/avg_combined_norm"] = mean_combined_norm
 
         lrs = lr_scheduler.get_last_lr()
         for i, lr in enumerate(lrs):
@@ -1101,7 +1108,8 @@ class NativeTrainer:
             "ss_validation_split": args.validation_split,
             "ss_max_validation_steps": args.max_validation_steps,
             "ss_validate_every_n_epochs": args.validate_every_n_epochs,
-            "ss_validate_every_n_steps": args.validate_every_n_steps,
+            "ss_validate_every_n_steps": args.validate_every_n_steps,
+            "ss_resize_interpolation": args.resize_interpolation,
         }
 
         self.update_metadata(metadata, args)  # architecture specific metadata
@@ -1126,7 +1134,8 @@ class NativeTrainer:
                 "min_bucket_reso": dataset.min_bucket_reso,
                 "max_bucket_reso": dataset.max_bucket_reso,
                 "tag_frequency": dataset.tag_frequency,
-                "bucket_info": dataset.bucket_info,
+                "bucket_info": dataset.bucket_info,
+                "resize_interpolation": dataset.resize_interpolation,
             }
 
             subsets_metadata = []
@@ -1144,6 +1153,7 @@ class NativeTrainer:
                     "enable_wildcard": bool(subset.enable_wildcard),
                     "caption_prefix": subset.caption_prefix,
                     "caption_suffix": subset.caption_suffix,
+                    "resize_interpolation": subset.resize_interpolation,
                 }
 
                 image_dir_or_metadata_file = None
@@ -1428,6 +1438,12 @@ class NativeTrainer:
                 if (args.max_grad_norm != 0.0) and hasattr(training_model, "get_trainable_params"):
                     params_to_clip = accelerator.unwrap_model(training_model).get_trainable_params()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                # lora_flux exclusive
+                if hasattr(training_model, "update_grad_norms"):
+                    training_model.update_grad_norms()
+                if hasattr(training_model, "update_norms"):
+                    training_model.update_norms()
 
                 optimizer.step()
                 lr_scheduler.step()
@@ -1439,9 +1455,14 @@ class NativeTrainer:
                         args.scale_weight_norms, accelerator.device
                     )  # TODO: Multiple models
+                    mean_grad_norm = None
+                    mean_combined_norm = None
                     max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
                 else:
                     keys_scaled, mean_norm, maximum_norm = None, None, None
+                    mean_grad_norm = None
+                    mean_combined_norm = None
+                    max_mean_logs = {}
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -1478,10 +1499,7 @@ class NativeTrainer:
                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                 avr_loss: float = loss_recorder.moving_average
                 logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
-                progress_bar.set_postfix(**logs)
-
-                if args.scale_weight_norms:
-                    progress_bar.set_postfix(**{**max_mean_logs, **logs})
+                progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if is_tracking:
                     logs = self.generate_step_logs(
@@ -1493,7 +1511,9 @@ class NativeTrainer:
                         optimizer,
                         keys_scaled,
                         mean_norm,
-                        maximum_norm
+                        maximum_norm,
+                        mean_grad_norm,
+                        mean_combined_norm,
                     )
                     accelerator.log(logs, step=global_step)