Compare commits

...

2 Commits

Author SHA1 Message Date
ddpasa
fe96db6802 Merge 1dc45f481a into 1dae34b0af 2026-03-30 19:26:43 +00:00
name
1dc45f481a train_network.py: add debugging print for tensor and gradient norms 2025-06-28 16:42:32 +02:00

View File

@@ -1439,6 +1439,24 @@ class NetworkTrainer:
if hasattr(network, "update_norms"):
network.update_norms()
if (args.debug_info_steps > 0) and (step % args.debug_info_steps == 0):
params = accelerator.unwrap_model(network).get_trainable_params()
grads, weights, numels = [], [], []
for p in params:
if p.requires_grad:
p_detached = p.detach()
weights.append(p_detached.norm(p=1).item())
if p.grad is not None:
grads.append(p.grad.detach().norm(p=1).item())
else:
grads.append(0.0)
numels.append(p_detached.numel())
total_grad = sum(grads) / sum(numels)
total_weight = sum(weights) / sum(numels)
accelerator.print(
f"\n[Step {step}] avr_grad={total_grad:.4E}, avr_weights={total_weight:.4E}"
)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1734,6 +1752,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing for U-Net or DiT, if supported"
" / 勾配チェックポイント時にテンソルをCPUにオフロードするU-NetまたはDiTのみ、サポートされている場合",
)
parser.add_argument(
"--debug_info_steps", type=int, default=0, help="Log gradient/weight norms every N steps"
)
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)