fix: validation with block swap

This commit is contained in:
Kohya S
2025-02-09 21:25:40 +09:00
parent 0911683717
commit 344845b429
3 changed files with 37 additions and 14 deletions

View File

@@ -309,7 +309,10 @@ class NetworkTrainer:
) -> torch.nn.Module:
return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass
# endregion
@@ -1278,7 +1281,7 @@ class NetworkTrainer:
original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep
def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
cpu_rng_state = torch.get_rng_state()
if accelerator.device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state()
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
batch,
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
break
for timestep in validation_timesteps:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1
if is_tracking:
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
batch,
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_ts_step += 1
if is_tracking: