fix: validation with block swap

This commit is contained in:
Kohya S
2025-02-09 21:25:40 +09:00
parent 0911683717
commit 344845b429
3 changed files with 37 additions and 14 deletions

View File

@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group) super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args) # sdxl_train_util.verify_sdxl_training_args(args)
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
@@ -507,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
# FLUX override of the hook that the validation loop calls after each validation step.
# When block swapping is active, it asks the unwrapped model to get ready for its next forward pass.
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for the next forward: no backward pass runs during validation, so the preparation has to be done here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__() super().__init__()
self.sample_prompts_te_outputs = None self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
# super().assert_extra_args(args, train_dataset_group) # super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args) # sdxl_train_util.verify_sdxl_training_args(args)
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
@@ -445,15 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
# SD3 override of the per-step hook: when the batch carries "text_encoder_outputs_list", the list is passed
# through the text encoding strategy's drop_cached_text_encoder_outputs and the result is stored back in the batch.
# NOTE(review): is_train is accepted but never read in this body. The validation loop appears to call this hook
# once per validation timestep on the same batch, so the drop would be re-applied to outputs that were already
# processed - confirm this is intended.
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# TODO consider validation # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list batch["text_encoder_outputs_list"] = text_encoder_outputs_list
# SD3 override of the hook that the validation loop calls after each validation step.
# When block swapping is active, it asks the unwrapped model to get ready for its next forward pass.
# NOTE(review): the __init__ lines visible in this diff do not assign self.is_swapping_blocks; presumably it is
# set elsewhere (e.g. while loading the model) before validation can run - confirm.
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for the next forward: no backward pass runs during validation, so the preparation has to be done here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -309,7 +309,10 @@ class NetworkTrainer:
) -> torch.nn.Module: ) -> torch.nn.Module:
return accelerator.prepare(unet) return accelerator.prepare(unet)
# Default no-op hooks on the base trainer; model-specific trainers override them.
# on_step_start is called before a batch is processed (is_train=True from the training loop, is_train=False from validation).
# on_validation_step_end is called after each validation step.
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass pass
# endregion # endregion
@@ -1278,7 +1281,7 @@ class NetworkTrainer:
original_args_min_timestep = args.min_timestep original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep original_args_max_timestep = args.max_timestep
def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
cpu_rng_state = torch.get_rng_state() cpu_rng_state = torch.get_rng_state()
if accelerator.device.type == "cuda": if accelerator.device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state() gpu_rng_state = torch.cuda.get_rng_state()
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet) on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing # preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch( loss = self.process_batch(
batch, batch,
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
break break
for timestep in validation_timesteps: for timestep in validation_timesteps:
# temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
} }
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1 val_ts_step += 1
if is_tracking: if is_tracking:
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep args.min_timestep = args.max_timestep = timestep
# temporary, for batch processing # temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch( loss = self.process_batch(
batch, batch,
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
} }
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1 val_ts_step += 1
if is_tracking: if is_tracking: