mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix: validation with block swap
@@ -309,7 +309,10 @@ class NetworkTrainer:
     ) -> torch.nn.Module:
         return accelerator.prepare(unet)
 
-    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
         pass
 
+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        pass
+
     # endregion
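
The two hooks above are the heart of the fix: on_step_start now carries an is_train flag so a subclass can tell training steps from validation steps, and on_validation_step_end gives it a cleanup point after each validation step. Below is a minimal sketch of a subclass using them; BlockSwapTrainer and the three helpers are made-up names for illustration, not sd-scripts APIs.

def prepare_blocks_for_training(unet): ...    # hypothetical stub
def prepare_blocks_for_validation(unet): ...  # hypothetical stub
def wait_for_block_swaps(unet): ...           # hypothetical stub

class BlockSwapTrainer(NetworkTrainer):
    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
        # is_train lets the hook stage blocks differently for train vs. eval.
        if is_train:
            prepare_blocks_for_training(unet)
        else:
            prepare_blocks_for_validation(unet)

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        # Cleanup point after each validation step, e.g. waiting for
        # in-flight block swaps before the next step runs.
        wait_for_block_swaps(unet)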
@@ -1278,7 +1281,7 @@ class NetworkTrainer:
         original_args_min_timestep = args.min_timestep
         original_args_max_timestep = args.max_timestep
 
-        def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
+        def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
             cpu_rng_state = torch.get_rng_state()
             if accelerator.device.type == "cuda":
                 gpu_rng_state = torch.cuda.get_rng_state()
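
switch_rng_state saves the current RNG states and re-seeds everything so each validation pass is deterministic; a matching restore then puts training randomness back. A self-contained sketch of that save/seed/restore pattern follows, using only standard torch and random calls; the restore helper's name is an assumption about the surrounding code, only the switch function's signature comes from the diff.

import random
from typing import Optional

import torch

def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
    # Capture CPU, CUDA, and python `random` states so training randomness
    # is untouched by the deterministic validation pass.
    cpu_rng_state = torch.get_rng_state()
    gpu_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    python_rng_state = random.getstate()

    # Re-seed everything so every validation run sees the same draws.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    random.seed(seed)

    return cpu_rng_state, gpu_rng_state, python_rng_state

def restore_rng_state(states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]) -> None:
    # Hypothetical counterpart: hand back the tuple switch_rng_state returned.
    cpu_rng_state, gpu_rng_state, python_rng_state = states
    torch.set_rng_state(cpu_rng_state)
    if gpu_rng_state is not None:
        torch.cuda.set_rng_state(gpu_rng_state)
    random.setstate(python_rng_state)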
@@ -1330,8 +1333,8 @@ class NetworkTrainer:
             with accelerator.accumulate(training_model):
                 on_step_start_for_network(text_encoder, unet)
 
-                # temporary, for batch processing
-                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                # preprocess batch for each model
+                self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
 
                 loss = self.process_batch(
                     batch,
@@ -1434,8 +1437,7 @@ class NetworkTrainer:
                         break
 
                     for timestep in validation_timesteps:
-                        # temporary, for batch processing
-                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
 
                         args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep
 
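
The "dirty hack" works because validation wants the loss at one fixed timestep per pass: pinning args.min_timestep and args.max_timestep to the same value collapses the sampling range to a single point, and the original_args_min_timestep / original_args_max_timestep saved in the earlier hunk restore the range afterwards. A toy sketch, assuming the sampler draws uniformly over an inclusive [min, max] range; the real sd-scripts sampler may treat the bounds differently.

import torch

def sample_timesteps(min_t: int, max_t: int, batch_size: int) -> torch.Tensor:
    # Toy stand-in for the trainer's timestep sampling; the inclusive upper
    # bound is an assumption made so that min_t == max_t is a valid fixed draw.
    return torch.randint(min_t, max_t + 1, (batch_size,))

print(sample_timesteps(0, 999, 4))    # training: varied timesteps per step
print(sample_timesteps(500, 500, 4))  # validation pin: tensor([500, 500, 500, 500])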
@@ -1471,6 +1473,7 @@ class NetworkTrainer:
                         }
                         accelerator.log(logs, step=global_step)
 
+                        self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                         val_ts_step += 1
 
                 if is_tracking:
@@ -1516,7 +1519,7 @@ class NetworkTrainer:
                         args.min_timestep = args.max_timestep = timestep
 
                         # temporary, for batch processing
-                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+                        self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
 
                         loss = self.process_batch(
                             batch,
@@ -1551,6 +1554,7 @@ class NetworkTrainer:
                         }
                         accelerator.log(logs, step=global_step)
 
+                        self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
                         val_ts_step += 1
 
                 if is_tracking: