Add gradient noise scale logging
@@ -4125,6 +4125,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
     )
+    parser.add_argument("--gradient_noise_scale", action="store_true", default=False, help="Calculate the gradient noise scale")
 
     if support_dreambooth:
         # DreamBooth training
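
The flag is an ordinary store_true switch, off by default, so the new bookkeeping costs nothing unless requested. A minimal standalone sketch of how the option behaves (this is not the sd-scripts entry point, just the flag in isolation):

import argparse

# standalone sketch mirroring only the flag added above, not the full parser
parser = argparse.ArgumentParser()
parser.add_argument("--gradient_noise_scale", action="store_true", default=False,
                    help="Calculate the gradient noise scale")

print(parser.parse_args([]).gradient_noise_scale)                          # False
print(parser.parse_args(["--gradient_noise_scale"]).gradient_noise_scale)  # True
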
@@ -293,6 +293,13 @@ class LoRAModule(torch.nn.Module):
         approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
         self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
 
+    def accumulate_grad(self):
+        # snapshot this step's gradients as a single flattened vector,
+        # so that LoRANetwork.all_grad() can stack one row per step
+        grads = [param.grad.view(-1) for param in self.parameters() if param.grad is not None]
+        if len(grads) > 0:
+            self.all_grad.append(torch.cat(grads))
+
     @property
     def device(self):
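
Each module keeps a growing list, all_grad, holding one flattened gradient vector per accumulation step; its initialization (presumably an empty list in __init__) is outside this diff. A toy, self-contained illustration of the flatten-and-concatenate pattern, with lin_down/lin_up standing in for the LoRA weights:

import torch

# lin_down / lin_up are toy stand-ins for lora_down / lora_up
lin_down = torch.nn.Linear(4, 2, bias=False)
lin_up = torch.nn.Linear(2, 4, bias=False)

loss = (lin_up(lin_down(torch.randn(8, 4))) ** 2).mean()
loss.backward()

# same pattern as accumulate_grad(): flatten each grad, concatenate into one
# vector, one such vector per optimization step
params = list(lin_down.parameters()) + list(lin_up.parameters())
step_grad = torch.cat([p.grad.view(-1) for p in params if p.grad is not None])
print(step_grad.shape)  # torch.Size([16]): 2*4 + 4*2 flattened weights
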
@@ -976,6 +983,35 @@ class LoRANetwork(torch.nn.Module):
             combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
         return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
 
+    def accumulate_grad(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.accumulate_grad()
+
+    def all_grad(self):
+        all_grad = []
+        for lora in self.text_encoder_loras + self.unet_loras:
+            # each entry: (num_steps, num_params_in_module)
+            all_grad.append(torch.stack(lora.all_grad))
+
+        # (num_steps, total_num_params)
+        return torch.cat(all_grad, dim=1)
+
+    def gradient_noise_scale(self):
+        all_grads = self.all_grad()
+        mean_grad = torch.mean(all_grads, dim=0)
+
+        # Calculate trace of covariance matrix
+        centered_grads = all_grads - mean_grad
+        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
+
+        # Calculate squared norm of mean gradient
+        grad_norm_squared = torch.sum(mean_grad**2)
+
+        # Calculate GNS as trace of covariance over squared gradient norm
+        gradient_noise_scale = trace_cov / grad_norm_squared
+
+        return gradient_noise_scale.item()
+
     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
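
gradient_noise_scale() implements the "simple" noise scale estimate B_simple = tr(Σ) / ||G||², the trace of the per-step gradient covariance divided by the squared norm of the mean gradient, in the spirit of McCandlish et al., "An Empirical Model of Large-Batch Training" (arXiv:1812.06162). A self-contained sanity check of the same arithmetic on synthetic gradients (all numbers here are illustrative, not from the repository):

import torch

torch.manual_seed(0)

# synthetic per-step gradients: a fixed "true" gradient g plus isotropic noise
num_steps, num_params, sigma = 64, 1000, 0.5
g = torch.randn(num_params)
all_grads = g + sigma * torch.randn(num_steps, num_params)  # (steps, params)

# same arithmetic as gradient_noise_scale() above
mean_grad = torch.mean(all_grads, dim=0)
centered_grads = all_grads - mean_grad
trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))  # tr(Sigma) estimate
grad_norm_squared = torch.sum(mean_grad**2)                  # ||G||^2 estimate

# expected roughly num_params * sigma^2 / ||g||^2, about 0.25 for this setup
print((trace_cov / grad_norm_squared).item())

Under the paper's model, B_simple approximates the batch size beyond which additional data parallelism gives diminishing returns, which is what makes it worth logging during training.
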
@@ -1418,6 +1418,8 @@ class NetworkTrainer:
                         network.update_grad_norms()
                     if hasattr(network, "update_norms"):
                         network.update_norms()
+                    if args.gradient_noise_scale and hasattr(network, "accumulate_grad"):
+                        network.accumulate_grad()
 
                     optimizer.step()
                     lr_scheduler.step()
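
The hook must run after the backward pass has populated .grad but before optimizer.step() and the subsequent zero_grad() consume the gradients. A toy, self-contained loop showing that ordering; model and opt are stand-ins, not the trainer's actual objects:

import torch

# toy stand-ins: "model" and "opt" play the roles of network and optimizer
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
all_grad = []

for _ in range(8):  # each iteration contributes one gradient sample
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # the accumulate_grad() equivalent: snapshot flattened grads while they exist
    all_grad.append(torch.cat([p.grad.view(-1) for p in model.parameters()]))
    opt.step()
    opt.zero_grad()

print(torch.stack(all_grad).shape)  # torch.Size([8, 5]) -> rows feed the GNS estimate
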
@@ -1491,6 +1493,8 @@ class NetworkTrainer:
                             mean_grad_norm,
                             mean_combined_norm,
                         )
+                    if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
+                        logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
                     self.step_logging(accelerator, logs, global_step, epoch + 1)
 
                     # VALIDATION PER STEP: global_step is already incremented