mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
use same noise for every validation
This commit is contained in:
@@ -377,7 +377,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
|
|||||||
@@ -1391,6 +1391,8 @@ class NetworkTrainer:
|
|||||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
accelerator.unwrap_model(network).eval()
|
accelerator.unwrap_model(network).eval()
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_total_steps),
|
range(validation_total_steps),
|
||||||
@@ -1451,6 +1453,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
torch.set_rng_state(rng_state)
|
||||||
args.min_timestep = original_args_min_timestep
|
args.min_timestep = original_args_min_timestep
|
||||||
args.max_timestep = original_args_max_timestep
|
args.max_timestep = original_args_max_timestep
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
@@ -1467,6 +1470,8 @@ class NetworkTrainer:
|
|||||||
if should_validate_epoch and len(val_dataloader) > 0:
|
if should_validate_epoch and len(val_dataloader) > 0:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
accelerator.unwrap_model(network).eval()
|
accelerator.unwrap_model(network).eval()
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_total_steps),
|
range(validation_total_steps),
|
||||||
@@ -1531,6 +1536,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
torch.set_rng_state(rng_state)
|
||||||
args.min_timestep = original_args_min_timestep
|
args.min_timestep = original_args_min_timestep
|
||||||
args.max_timestep = original_args_max_timestep
|
args.max_timestep = original_args_max_timestep
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
|
|||||||
Reference in New Issue
Block a user