mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Compare commits
4 Commits
v0.10.2
...
abacf4978a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
abacf4978a | ||
|
|
fa53f71ec0 | ||
|
|
15136ca505 | ||
|
|
3f47806719 |
@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
|
||||
@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
|
||||
@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
use_fp32: bool = False,
|
||||
):
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
@@ -863,11 +863,11 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
if use_fp32:
|
||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
@@ -959,6 +959,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -972,6 +973,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -994,6 +996,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
|
||||
|
||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||
|
||||
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.wait_for_block(block_idx)
|
||||
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ class LoRAModule(torch.nn.Module):
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
mgpo_rho: float | None = None,
|
||||
mgpo_beta: float | None = None,
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
@@ -117,6 +119,25 @@ class LoRAModule(torch.nn.Module):
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
self.mgpo_rho = mgpo_rho
|
||||
self.mgpo_beta = mgpo_beta
|
||||
|
||||
# EMA of gradient magnitudes for adaptive normalization
|
||||
self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
|
||||
self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)
|
||||
|
||||
self.optimizer: torch.optim.Optimizer | None = None
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
@@ -158,6 +179,18 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
# LoRA Momentum-Guided Perturbation Optimization (MGPO)
|
||||
if (
|
||||
self.training
|
||||
and hasattr(self, "mgpo_rho")
|
||||
and self.mgpo_rho is not None
|
||||
and hasattr(self, "optimizer")
|
||||
and self.optimizer is not None
|
||||
):
|
||||
mgpo_perturbation_output = self.get_mgpo_output_perturbation(x)
|
||||
if mgpo_perturbation_output is not None:
|
||||
return org_forwarded + (self.multiplier * scale * lx) + mgpo_perturbation_output
|
||||
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if (
|
||||
self.training
|
||||
@@ -304,6 +337,97 @@ class LoRAModule(torch.nn.Module):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
def update_gradient_ema(self):
    """
    Update the EMAs of the LoRA gradient magnitudes used for adaptive perturbation normalization.

    Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂

    One EMA is kept for the down projection(s) and one for the up projection(s).
    Does nothing when ``mgpo_beta`` is None. A side whose weights have no gradient
    is left untouched.

    When ``split_dims`` is set, ``lora_down`` / ``lora_up`` are iterated as collections
    of modules (the same way ``get_mgpo_output_perturbation`` iterates them) and the
    gradient magnitude of a side is the L2 norm over all of its gradients taken together.
    """
    if self.mgpo_beta is None:
        return

    # With split dims each side holds several modules instead of exposing `.weight` directly.
    is_split = getattr(self, "split_dims", None) is not None

    sides = (
        (self.lora_down, self._grad_magnitude_ema_down),
        (self.lora_up, self._grad_magnitude_ema_up),
    )

    # Pure bookkeeping: nothing here should be tracked by autograd.
    with torch.no_grad():
        for lora, ema in sides:
            modules = lora if is_split else (lora,)
            grads = [m.weight.grad for m in modules if m.weight.grad is not None]
            if not grads:
                continue

            if len(grads) == 1:
                grad_norm = torch.norm(grads[0], p=2)
            else:
                # The norm of the per-module norms equals the L2 norm of all gradients concatenated.
                grad_norm = torch.norm(torch.stack([torch.norm(g, p=2) for g in grads]), p=2)

            ema.mul_(self.mgpo_beta).add_(grad_norm, alpha=(1 - self.mgpo_beta))
|
||||
|
||||
def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
|
||||
"""
|
||||
Generate MGPO perturbation using both momentum direction and gradient magnitude normalization
|
||||
|
||||
Full MGPO Formula: ε = -ρ · (vₜ / ||vₜ||₂) · (ḡₗ⁽ᵗ⁾)⁻¹
|
||||
Where:
|
||||
- ε = perturbation vector
|
||||
- ρ = perturbation radius (mgpo_rho)
|
||||
- vₜ = momentum vector from optimizer (exp_avg) - provides DIRECTION
|
||||
- ||vₜ||₂ = L2 norm of momentum for unit direction
|
||||
- ḡₗ⁽ᵗ⁾ = EMA of gradient magnitude - provides ADAPTIVE SCALING
|
||||
|
||||
Two separate EMAs:
|
||||
1. Momentum EMA (from Adam): vₜ = β₁ * vₜ₋₁ + (1 - β₁) * ∇L(Wₜ)
|
||||
2. Gradient Magnitude EMA: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇L(Wₜ)||₂
|
||||
"""
|
||||
if self.optimizer is None or self.mgpo_rho is None or self.mgpo_beta is None:
|
||||
return None
|
||||
|
||||
total_perturbation_scale = 0.0
|
||||
valid_params = 0
|
||||
|
||||
# Handle both single and split dims cases
|
||||
if self.split_dims is None:
|
||||
params_and_emas = [
|
||||
(self.lora_down.weight, self._grad_magnitude_ema_down),
|
||||
(self.lora_up.weight, self._grad_magnitude_ema_up),
|
||||
]
|
||||
else:
|
||||
# For split dims, use average EMA (or extend to per-param EMAs)
|
||||
avg_ema = (self._grad_magnitude_ema_down + self._grad_magnitude_ema_up) / 2
|
||||
params_and_emas = []
|
||||
for lora_down in self.lora_down:
|
||||
params_and_emas.append((lora_down.weight, avg_ema))
|
||||
for lora_up in self.lora_up:
|
||||
params_and_emas.append((lora_up.weight, avg_ema))
|
||||
|
||||
for param, grad_ema in params_and_emas:
|
||||
if param in self.optimizer.state and "exp_avg" in self.optimizer.state[param]:
|
||||
# Get momentum direction: vₜ / ||vₜ||₂
|
||||
momentum = self.optimizer.state[param]["exp_avg"]
|
||||
momentum_norm = torch.norm(momentum, p=2)
|
||||
|
||||
if momentum_norm > 1e-8 and grad_ema > 1e-8:
|
||||
# Apply full MGPO formula: ρ · (momentum_direction) · (1/grad_magnitude_ema)
|
||||
direction_component = momentum_norm # We'll use this for scaling
|
||||
adaptive_scale = 1.0 / grad_ema # Adaptive normalization
|
||||
|
||||
perturbation_scale = self.mgpo_rho * direction_component * adaptive_scale
|
||||
total_perturbation_scale += perturbation_scale.item()
|
||||
valid_params += 1
|
||||
|
||||
if valid_params == 0:
|
||||
return None
|
||||
|
||||
# Average perturbation scale across all valid parameters
|
||||
avg_perturbation_scale = total_perturbation_scale / valid_params
|
||||
|
||||
with torch.no_grad():
|
||||
# Generate random perturbation scaled by MGPO formula
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(avg_perturbation_scale)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
|
||||
return perturbation_output
|
||||
|
||||
def register_optimizer(self, optimizer):
    """Store ``optimizer`` on this module as ``self.optimizer``."""
    setattr(self, "optimizer", optimizer)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
@@ -574,6 +698,15 @@ def create_network(
|
||||
if ggpo_sigma is not None:
|
||||
ggpo_sigma = float(ggpo_sigma)
|
||||
|
||||
mgpo_beta = kwargs.get("mgpo_beta", None)
|
||||
mgpo_rho = kwargs.get("mgpo_rho", None)
|
||||
|
||||
if mgpo_beta is not None:
|
||||
mgpo_beta = float(mgpo_beta)
|
||||
|
||||
if mgpo_rho is not None:
|
||||
mgpo_rho = float(mgpo_rho)
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -642,6 +775,8 @@ def create_network(
|
||||
reg_dims=reg_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
mgpo_rho=mgpo_rho,
|
||||
mgpo_beta=mgpo_beta,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
@@ -745,6 +880,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
reg_dims: Optional[Dict[str, int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
mgpo_rho: Optional[float] = None,
|
||||
mgpo_beta: Optional[float] = None,
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
@@ -790,6 +927,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
if ggpo_beta is not None and ggpo_sigma is not None:
|
||||
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
|
||||
|
||||
if mgpo_beta is not None and mgpo_rho is not None:
|
||||
logger.info(f"LoRA-MGPO training rho: {mgpo_rho} beta: {mgpo_beta}")
|
||||
if self.split_qkv:
|
||||
logger.info(f"split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
@@ -924,6 +1063,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
mgpo_rho=mgpo_rho,
|
||||
mgpo_beta=mgpo_beta,
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
|
||||
119
tests/networks/test_lora_flux_mgpo.py
Normal file
119
tests/networks/test_lora_flux_mgpo.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
from networks.lora_flux import LoRAModule
|
||||
|
||||
|
||||
class MockLinear(torch.nn.Module):
    """Minimal bias-free linear layer used as the wrapped module in these tests."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight is laid out as (out_features, in_features) and drawn from a standard normal.
        initial_weight = torch.randn(out_features, in_features)
        self.weight = torch.nn.Parameter(initial_weight)
        self.in_features, self.out_features = in_features, out_features

    def forward(self, x):
        """Return ``x`` multiplied by the transposed weight."""
        return torch.matmul(x, self.weight.t())

    def state_dict(self):
        """Return a dict holding only the weight parameter itself."""
        return dict(weight=self.weight)
|
||||
|
||||
|
||||
class MockOptimizer:
    """Optimizer stand-in exposing only a ``state`` mapping with a random ``exp_avg`` for one parameter."""

    def __init__(self, param):
        fake_momentum = torch.randn_like(param)
        self.state = {param: {"exp_avg": fake_momentum}}
|
||||
|
||||
|
||||
@pytest.fixture
def lora_module():
    """Provide a LoRAModule built around a 10 -> 20 MockLinear with MGPO enabled (rho=0.1, beta=0.9)."""
    wrapped = MockLinear(10, 20)
    module = LoRAModule(wrapped, wrapped, multiplier=1.0, lora_dim=4, alpha=1.0, mgpo_rho=0.1, mgpo_beta=0.9)
    # Set org_module_shape by hand so that it matches the wrapped module's weight
    module.org_module_shape = wrapped.weight.shape
    return module
|
||||
|
||||
|
||||
def test_mgpo_parameter_initialization(lora_module):
    """Test MGPO-specific parameter initialization."""
    # MGPO hyperparameters are stored as given
    assert hasattr(lora_module, "mgpo_rho")
    assert hasattr(lora_module, "mgpo_beta")
    assert lora_module.mgpo_rho == 0.1
    assert lora_module.mgpo_beta == 0.9

    # The gradient-magnitude EMAs are registered with register_buffer(..., persistent=False):
    # they are plain tensors (not trainable nn.Parameters), are not part of the state dict,
    # and start at 1.0.
    buffer_names = {name for name, _ in lora_module.named_buffers()}
    parameter_names = {name for name, _ in lora_module.named_parameters()}
    saved_names = set(lora_module.state_dict().keys())

    for ema_name in ("_grad_magnitude_ema_down", "_grad_magnitude_ema_up"):
        assert hasattr(lora_module, ema_name)
        assert ema_name in buffer_names
        assert ema_name not in parameter_names
        assert ema_name not in saved_names

        ema = getattr(lora_module, ema_name)
        assert isinstance(ema, torch.Tensor)
        assert not isinstance(ema, torch.nn.Parameter)
        assert not ema.requires_grad
        assert ema.item() == 1.0
|
||||
|
||||
|
||||
def test_update_gradient_ema(lora_module):
    """Test gradient EMA update method."""
    beta = lora_module.mgpo_beta
    down_weight = lora_module.lora_down.weight
    up_weight = lora_module.lora_up.weight

    # Give both LoRA weights a random gradient so the update has something to measure
    down_weight.grad = torch.randn_like(down_weight)
    up_weight.grad = torch.randn_like(up_weight)

    # Snapshot the EMAs before updating
    ema_down_before = lora_module._grad_magnitude_ema_down.clone()
    ema_up_before = lora_module._grad_magnitude_ema_up.clone()

    lora_module.update_gradient_ema()

    # Expected value: beta * previous + (1 - beta) * ||grad||_2
    expected_down_ema = beta * ema_down_before + (1 - beta) * torch.norm(down_weight.grad, p=2)
    expected_up_ema = beta * ema_up_before + (1 - beta) * torch.norm(up_weight.grad, p=2)

    assert torch.allclose(lora_module._grad_magnitude_ema_down, expected_down_ema, rtol=1e-5)
    assert torch.allclose(lora_module._grad_magnitude_ema_up, expected_up_ema, rtol=1e-5)

    # With mgpo_beta unset the call must not raise
    lora_module.mgpo_beta = None
    lora_module.update_gradient_ema()
|
||||
|
||||
|
||||
def test_get_mgpo_output_perturbation(lora_module):
    """Test MGPO perturbation generation."""
    # Register a mock optimizer that holds momentum for the down projection only
    mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
    lora_module.register_optimizer(mock_optimizer)

    x = torch.randn(5, 10)  # batch × input_dim

    # With optimizer, rho and beta all set, a perturbation is produced
    perturbation = lora_module.get_mgpo_output_perturbation(x)
    assert perturbation is not None
    assert isinstance(perturbation, torch.Tensor)
    assert perturbation.shape == (x.shape[0], lora_module.org_module.out_features)
    # The perturbation is built under torch.no_grad(), so it must not carry gradients
    assert not perturbation.requires_grad

    # Each precondition must disable the perturbation on its own, not only all of them together
    for attr in ("optimizer", "mgpo_rho", "mgpo_beta"):
        saved = getattr(lora_module, attr)
        setattr(lora_module, attr, None)
        assert lora_module.get_mgpo_output_perturbation(x) is None, f"{attr}=None should disable MGPO"
        setattr(lora_module, attr, saved)

    # Restoring the attributes re-enables the perturbation
    assert lora_module.get_mgpo_output_perturbation(x) is not None

    # An optimizer without recorded momentum yields no perturbation
    mock_optimizer.state.clear()
    assert lora_module.get_mgpo_output_perturbation(x) is None

    # All conditions unmet at once
    lora_module.optimizer = None
    lora_module.mgpo_rho = None
    lora_module.mgpo_beta = None
    assert lora_module.get_mgpo_output_perturbation(x) is None
|
||||
|
||||
|
||||
def test_register_optimizer(lora_module):
    """Test optimizer registration method."""
    optimizer_stub = MockOptimizer(lora_module.lora_down.weight)

    lora_module.register_optimizer(optimizer_stub)

    # The stub must now be reachable through the `optimizer` attribute
    assert hasattr(lora_module, "optimizer")
    assert lora_module.optimizer == optimizer_stub
|
||||
@@ -750,6 +750,9 @@ class NetworkTrainer:
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
|
||||
if hasattr(network, "register_optimizer"):
|
||||
network.register_optimizer(optimizer)
|
||||
|
||||
# prepare dataloader
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
@@ -1438,6 +1441,8 @@ class NetworkTrainer:
|
||||
network.update_grad_norms()
|
||||
if hasattr(network, "update_norms"):
|
||||
network.update_norms()
|
||||
if hasattr(network, "update_gradient_ema"):
|
||||
network.update_gradient_ema()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
Reference in New Issue
Block a user