From 344845b42941b48956dce94d614fbf32e900c70e Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 9 Feb 2025 21:25:40 +0900
Subject: [PATCH] fix: validation with block swap

---
 flux_train_network.py | 14 ++++++++++++--
 sd3_train_network.py  | 19 ++++++++++++++-----
 train_network.py      | 18 +++++++++++-------
 3 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index 475bd751..e97dfc5b 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         self.is_schnell: Optional[bool] = None
         self.is_swapping_blocks: bool = False
 
-    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
+    def assert_extra_args(
+        self,
+        args,
+        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
+        val_dataset_group: Optional[train_util.DatasetGroup],
+    ):
         super().assert_extra_args(args, train_dataset_group, val_dataset_group)
         # sdxl_train_util.verify_sdxl_training_args(args)
 
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
@@ -507,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 text_encoder.to(te_weight_dtype)  # fp8
                 prepare_fp8(text_encoder, weight_dtype)
 
+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        if self.is_swapping_blocks:
+            # prepare for next forward: because backward pass is not called, we need to prepare it here
+            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
diff --git a/sd3_train_network.py b/sd3_train_network.py
index d4f13125..216d93c5 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         super().__init__()
         self.sample_prompts_te_outputs = None
 
-    def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
+    def assert_extra_args(
+        self,
+        args,
+        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
+        val_dataset_group: Optional[train_util.DatasetGroup],
+    ):
         # super().assert_extra_args(args, train_dataset_group)
         # sdxl_train_util.verify_sdxl_training_args(args)
 
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         network,
         weight_dtype,
         train_unet,
-        is_train=True
+        is_train=True,
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
@@ -445,15 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
                 text_encoder.to(te_weight_dtype)  # fp8
                 prepare_fp8(text_encoder, weight_dtype)
 
-    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # TODO consider validation
-        # drop cached text encoder outputs
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
+        # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
             text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
             batch["text_encoder_outputs_list"] = text_encoder_outputs_list
 
+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        if self.is_swapping_blocks:
+            # prepare for next forward: because backward pass is not called, we need to prepare it here
+            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
diff --git a/train_network.py b/train_network.py
index 083e5993..49013c70 100644
--- a/train_network.py
+++ b/train_network.py
@@ -309,7 +309,10 @@ class NetworkTrainer:
     ) -> torch.nn.Module:
         return accelerator.prepare(unet)
 
-    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
+        pass
+
+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
         pass
 
     # endregion
@@ -1278,7 +1281,7 @@ class NetworkTrainer:
             original_args_min_timestep = args.min_timestep
             original_args_max_timestep = args.max_timestep
 
-            def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
+            def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
                 cpu_rng_state = torch.get_rng_state()
                 if accelerator.device.type == "cuda":
                     gpu_rng_state = torch.cuda.get_rng_state()
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
                 with accelerator.accumulate(training_model):
                     on_step_start_for_network(text_encoder, unet)
 
-                    # temporary, for batch processing
-                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                    # preprocess batch for each model
+                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
 
                     loss = self.process_batch(
                         batch,
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
                                     break
 
                                 for timestep in validation_timesteps:
-                                    # temporary, for batch processing
-                                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
 
                                     args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep
 
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
                                         }
                                         accelerator.log(logs, step=global_step)
 
+                                    self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                                     val_ts_step += 1
 
                                 if is_tracking:
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
                                 args.min_timestep = args.max_timestep = timestep
 
                                 # temporary, for batch processing
-                                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
 
                                 loss = self.process_batch(
                                     batch,
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
                                     }
                                     accelerator.log(logs, step=global_step)
 
+                                self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                                 val_ts_step += 1
 
                                 if is_tracking:
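
Note for reviewers: the sketch below shows how a downstream NetworkTrainer subclass might use the two hooks this patch touches. `MyTrainer` and the wiring of `is_swapping_blocks` are hypothetical illustrations, not part of the patch; only the hook signatures and the `prepare_block_swap_before_forward()` call are taken from the diff above.

import train_network


class MyTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        # Hypothetical: a real trainer would set this when block swap is enabled.
        self.is_swapping_blocks = False

    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
        # Per-batch preprocessing. `is_train` tells the trainer whether a
        # backward pass will follow (training) or not (validation).
        pass

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        # During training, the backward pass restages swapped blocks for the
        # next forward. Validation runs forward-only, so without this call the
        # next forward would start with blocks still offloaded.
        if self.is_swapping_blocks:
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()

Both hooks default to no-ops in the base NetworkTrainer, so trainers that do not override them are unaffected by this change.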