From 1231f5114ccd6a0a26a53da82b89083299ccc333 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 7 Jan 2025 22:31:41 -0500 Subject: [PATCH] Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code --- library/train_util.py | 70 ++++++++++++++++--------------------------- train_network.py | 66 ++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0907a8c0..b8894752 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common( huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True) -def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor: +def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor: timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - return timesteps - - - - -def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: - """ - Apply noise modifications like noise offset and multires noise - """ - if args.noise_offset: - if args.noise_offset_random_strength: - noise_offset = torch.rand(1, device=latents.device) * args.noise_offset - else: - noise_offset = args.noise_offset - noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale) - if args.multires_noise_iterations: - noise = custom_train_functions.pyramid_noise_like( - noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount - ) - return noise - - -def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor: - """ - Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). See `modify_noise` - """ - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - noise = modify_noise(args, noise, latents) - - return typing.cast(torch.FloatTensor, noise) - - -def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor: - """ - From args, produce random timesteps for each image in the batch - """ - min_timestep = 0 if args.min_timestep is None else args.min_timestep - max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep - - # Sample a random timestep for each image - timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device) - + timesteps = timesteps.long().to(device) return timesteps @@ -6457,6 +6415,30 @@ def sample_image_inference( wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption +def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str): + """ + Initialize experiment trackers with tracker specific behaviors + """ + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + default_tracker_name if args.log_tracker_name is None else args.log_tracker_name, + config=get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) + + if "wandb" in [tracker.name for tracker in accelerator.trackers]: + import wandb + wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) + + # Define specific metrics to handle validation and epochs "steps" + wandb_tracker.define_metric("epoch", hidden=True) + wandb_tracker.define_metric("val_step", hidden=True) + # endregion diff --git a/train_network.py b/train_network.py index d0596fca..199f589b 100644 --- a/train_network.py +++ b/train_network.py @@ -327,8 +327,8 @@ class NetworkTrainer: weight_dtype, accelerator, args, - text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, - tokenize_strategy: strategy_sd.SdTokenizeStrategy, + text_encoding_strategy: strategy_base.TextEncodingStrategy, + tokenize_strategy: strategy_base.TokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True @@ -1183,17 +1183,7 @@ class NetworkTrainer: noise_scheduler = self.get_noise_scheduler(args, accelerator.device) - if accelerator.is_main_process: - init_kwargs = {} - if args.wandb_run_name: - init_kwargs["wandb"] = {"name": args.wandb_run_name} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, - config=train_util.get_sanitized_config_or_none(args), - init_kwargs=init_kwargs, - ) + train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() @@ -1386,15 +1376,14 @@ class NetworkTrainer: mean_norm, maximum_norm ) - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + accelerator.log(logs, step=global_step) # VALIDATION PER STEP - should_validate_epoch = ( + should_validate_step = ( args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0 ) - if validation_steps > 0 and should_validate_epoch: + if validation_steps > 0 and should_validate_step: accelerator.print("Validating バリデーション処理...") val_progress_bar = tqdm( @@ -1406,6 +1395,9 @@ class NetworkTrainer: if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1428,18 +1420,22 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/step_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/step/current": current_loss, + "val_step": (epoch * validation_steps) + val_step, + } + accelerator.log(logs, step=global_step) - logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average} - # accelerator.log(logs, step=global_step) - accelerator.log(logs) + if is_tracking: + logs = { + "loss/validation/step/average": val_step_loss_recorder.moving_average, + } + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - # VALIDATION EPOCH + # EPOCH VALIDATION should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None @@ -1458,6 +1454,9 @@ class NetworkTrainer: if val_step >= validation_steps: break + # temporary, for batch processing + self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) + loss = self.process_batch( batch, text_encoders, @@ -1480,21 +1479,22 @@ class NetworkTrainer: val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) if is_tracking: - logs = {"loss/epoch_validation_current": current_loss} - # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) - accelerator.log(logs) + logs = { + "loss/validation/epoch_current": current_loss, + "epoch": epoch + 1, + "val_step": (epoch * validation_steps) + val_step + } + accelerator.log(logs, step=global_step) if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average - logs = {"loss/epoch_validation_average": avr_loss} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) # END OF EPOCH if is_tracking: - logs = {"loss/epoch_average": loss_recorder.moving_average} - # accelerator.log(logs, step=epoch + 1) - accelerator.log(logs) + logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} + accelerator.log(logs, step=global_step) accelerator.wait_for_everyone()