mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix: validation with block swap
This commit is contained in:
@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
self.is_schnell: Optional[bool] = None
|
self.is_schnell: Optional[bool] = None
|
||||||
self.is_swapping_blocks: bool = False
|
self.is_swapping_blocks: bool = False
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=True
|
is_train=True,
|
||||||
):
|
):
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
@@ -507,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoder.to(te_weight_dtype) # fp8
|
text_encoder.to(te_weight_dtype) # fp8
|
||||||
prepare_fp8(text_encoder, weight_dtype)
|
prepare_fp8(text_encoder, weight_dtype)
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
def prepare_unet_with_accelerator(
|
def prepare_unet_with_accelerator(
|
||||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.sample_prompts_te_outputs = None
|
self.sample_prompts_te_outputs = None
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
# super().assert_extra_args(args, train_dataset_group)
|
# super().assert_extra_args(args, train_dataset_group)
|
||||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=True
|
is_train=True,
|
||||||
):
|
):
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
@@ -445,15 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoder.to(te_weight_dtype) # fp8
|
text_encoder.to(te_weight_dtype) # fp8
|
||||||
prepare_fp8(text_encoder, weight_dtype)
|
prepare_fp8(text_encoder, weight_dtype)
|
||||||
|
|
||||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
|
||||||
# TODO consider validation
|
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
|
||||||
# drop cached text encoder outputs
|
|
||||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
if text_encoder_outputs_list is not None:
|
if text_encoder_outputs_list is not None:
|
||||||
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
def prepare_unet_with_accelerator(
|
def prepare_unet_with_accelerator(
|
||||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
|||||||
@@ -309,7 +309,10 @@ class NetworkTrainer:
|
|||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
return accelerator.prepare(unet)
|
return accelerator.prepare(unet)
|
||||||
|
|
||||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
|
|||||||
with accelerator.accumulate(training_model):
|
with accelerator.accumulate(training_model):
|
||||||
on_step_start_for_network(text_encoder, unet)
|
on_step_start_for_network(text_encoder, unet)
|
||||||
|
|
||||||
# temporary, for batch processing
|
# preprocess batch for each model
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
|
|||||||
break
|
break
|
||||||
|
|
||||||
for timestep in validation_timesteps:
|
for timestep in validation_timesteps:
|
||||||
# temporary, for batch processing
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
|
||||||
|
|
||||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||||
|
|
||||||
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
val_ts_step += 1
|
val_ts_step += 1
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
|
|||||||
args.min_timestep = args.max_timestep = timestep
|
args.min_timestep = args.max_timestep = timestep
|
||||||
|
|
||||||
# temporary, for batch processing
|
# temporary, for batch processing
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
val_ts_step += 1
|
val_ts_step += 1
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
|
|||||||
Reference in New Issue
Block a user