use same noise for every validation

This commit is contained in:
Kohya S
2025-01-27 22:10:38 +09:00
parent 42c0a9e1fc
commit 45ec02b2a8
2 changed files with 6 additions and 1 deletions

View File

@@ -377,7 +377,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(

View File

@@ -1391,6 +1391,8 @@ class NetworkTrainer:
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_state = torch.get_rng_state()
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_total_steps),
@@ -1451,6 +1453,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
torch.set_rng_state(rng_state)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
@@ -1467,6 +1470,8 @@ class NetworkTrainer:
if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_state = torch.get_rng_state()
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_total_steps),
@@ -1531,6 +1536,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
torch.set_rng_state(rng_state)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()