mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: validation with block swap
This commit is contained in:
@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
is_train=True,
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -507,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
|
||||
@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
is_train=True,
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -445,15 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
# TODO consider validation
|
||||
# drop cached text encoder outputs
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
|
||||
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
|
||||
@@ -309,7 +309,10 @@ class NetworkTrainer:
|
||||
) -> torch.nn.Module:
|
||||
return accelerator.prepare(unet)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
|
||||
pass
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
pass
|
||||
|
||||
# endregion
|
||||
@@ -1278,7 +1281,7 @@ class NetworkTrainer:
|
||||
original_args_min_timestep = args.min_timestep
|
||||
original_args_max_timestep = args.max_timestep
|
||||
|
||||
def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
if accelerator.device.type == "cuda":
|
||||
gpu_rng_state = torch.cuda.get_rng_state()
|
||||
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
|
||||
break
|
||||
|
||||
for timestep in validation_timesteps:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
val_ts_step += 1
|
||||
|
||||
if is_tracking:
|
||||
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
|
||||
args.min_timestep = args.max_timestep = timestep
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
val_ts_step += 1
|
||||
|
||||
if is_tracking:
|
||||
|
||||
Reference in New Issue
Block a user