mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
700 lines
28 KiB
Python
700 lines
28 KiB
Python
import torch
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
import bitsandbytes.functional as F
|
|
from torch.nn.functional import normalize
|
|
from dataclasses import dataclass
|
|
import math
|
|
|
|
@dataclass
class OptimizerConfig:
    """Configuration for the Automagic_CameAMP optimizer.

    Bundles every hyperparameter shared by the fp32 and 8-bit variants.
    Tuple-valued fields follow the (1st, 2nd, 3rd) moment convention used
    throughout the optimizer: eps = (eps1, eps2, eps3), betas = (beta1,
    beta2, beta3).
    """

    # Base learning rate; also seeds every per-parameter lr mask entry.
    lr: float = 1e-6

    # Lower clamp for per-parameter lr mask entries.
    min_lr: float = 1e-7

    # Upper clamp for per-parameter lr mask entries.
    max_lr: float = 1e-3

    # Additive step applied to a mask entry when the gradient sign
    # agrees (+) or disagrees (-) with the previous step, during warmup.
    lr_bump: float = 3e-6

    # (eps1, eps2, eps3) numerical-stability constants; eps1 guards the
    # gradient scaling, eps2 the CAME residual. eps3 appears unused by the
    # visible code paths.
    eps: Tuple[float, float, float] = (1e-30, 1e-16, 1e-8)

    # RMS clip threshold for the normalized update (8-bit variant).
    clip_threshold: float = 1.0

    # (beta1, beta2, beta3) EMA coefficients: momentum, second moment,
    # CAME residual respectively.
    betas: Tuple[float, float, float] = (0.8, 0.99, 0.999)

    # ALLoRA eta: row scaling is 1/sqrt(||row|| + 1/eta^2) for 2D params.
    eta: float = 2.0

    # Per-step multiplicative decay of beta1 in the consistency-momentum
    # phase (floored at 0.4 by _update_consistency_momentum).
    beta1_decay: float = 0.9995

    # SPD projection strength. NOTE(review): despite the name this is not a
    # classic AdamW-style decay coefficient — it scales the selective
    # projection toward the "pre" weights; 1.0 is intentional there.
    weight_decay: float = 1.0

    # Number of steps during which warmup-only mechanisms run
    # (lr-mask growth, TAM momentum, Grams, orthograd, SPD).
    warmup_steps: int = 500

    # Apply "cautious" sign-agreement masking to updates after early training.
    cautious: bool = True

    # If True, snapshot initial weights as "pre" for SPD; if False (LoRA-style
    # training) SPD projects toward zero and ALLoRA row scaling is enabled.
    full_finetune: bool = False

    # Print each group's lr every step.
    verbose: bool = False
|
|
|
|
class BaseOptimizer(torch.optim.Optimizer):
    """Base class for Automagic optimizers with common functionality.

    Hosts the shared numeric building blocks (RMS, factored second-moment
    approximation, orthogonal-gradient projection, TAM / consistency
    momentum, SPD decision logic) and the per-parameter learning-rate mask
    machinery. Subclasses implement ``step`` and may override the state
    initialization and lr-mask update hooks.
    """

    def __init__(self, params, config: OptimizerConfig):
        """Build the torch param groups from an OptimizerConfig.

        Args:
            params: Iterable of parameters or param-group dicts (standard
                torch.optim.Optimizer input).
            config: Hyperparameter bundle; its fields become group defaults.
        """
        self.config = config
        # Handle eta value: if not float use 2.0
        eta_value = float(config.eta) if isinstance(config.eta, (int, float)) else 2.0

        defaults = dict(
            lr=config.lr,
            eps=config.eps,
            clip_threshold=config.clip_threshold,
            betas=config.betas,
            eta=eta_value,
            beta1_decay=config.beta1_decay,
            weight_decay=config.weight_decay,
            warmup_steps=config.warmup_steps,
            cautious=config.cautious,
            full_finetune=config.full_finetune,
        )
        super().__init__(params, defaults)
        # One base lr per group (all seeded from config.lr).
        self.base_lrs: List[float] = [config.lr for group in self.param_groups]

    @staticmethod
    def _rms(tensor: torch.Tensor) -> torch.Tensor:
        """Calculate root mean square of tensor (1e-10 guards empty/zero)."""
        return tensor.norm(2) / (tensor.numel() ** 0.5 + 1e-10)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row: torch.Tensor, exp_avg_sq_col: torch.Tensor) -> torch.Tensor:
        """Approximate square gradient for factored matrices.

        Adafactor-style rank-1 reconstruction of 1/sqrt(second moment) from
        the row and column running averages.
        """
        # Row factor is normalized by its mean so the outer product keeps the
        # overall scale in the column factor.
        r_factor = (exp_avg_sq_row / (exp_avg_sq_row.mean(dim=-1, keepdim=True) + 1e-12)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @staticmethod
    def _ratio(new_p: torch.Tensor, p: torch.Tensor, pre: torch.Tensor) -> torch.Tensor:
        """Calculate the ratio for selective projection decay.

        Compares how far the candidate weights (new_p) vs the current
        weights (p) have drifted from the anchor (pre); clamped to [0, 1].
        """
        curr_norm, prev_norm = torch.norm(new_p - pre), torch.norm(p - pre)
        ratio = (curr_norm - prev_norm) / (curr_norm + 1e-8)
        # hardtanh with bounds (0, 1) acts as a clamp to [0, 1].
        return torch.nn.functional.hardtanh(ratio, 0.0, 1.0)

    # Implementation from: https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability/blob/main/orthograd.py
    @staticmethod
    def orthograd_(p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Project grad onto the subspace orthogonal to p, rescaled to the
        original gradient norm.

        NOTE: mutates ``grad`` in place (``sub_``/``mul_`` act on a view of
        it), as the trailing underscore in the name indicates.
        """
        # Degenerate weights: nothing to project against.
        if p.norm(2) <= 1e-30:
            return grad

        G_shape = grad.shape
        w = p.view(-1)
        g = grad.view(-1)
        g_norm = g.norm(2)

        # Component of g along w.
        proj = torch.dot(w, g) / torch.dot(w, w).add(1e-30)
        g_orth = g.sub_(w, alpha=proj)
        # Restore the pre-projection norm.
        g_orth_scaled = g_orth.mul_(g_norm / g_orth.norm(2).add(1e-30))

        return g_orth_scaled.view(G_shape)

    @staticmethod
    def _should_apply_spd(state: Dict[str, Any], group: Dict[str, Any], p: torch.Tensor, grad: torch.Tensor) -> bool:
        """Determine if SPD (Selective Projection Decay) should be applied.

        SPD is only active during warmup; within it, it triggers when the
        gradient is anti-aligned with the displacement from the anchor
        weights (i.e. the step would move weights further from "pre").

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            p: Parameter tensor
            grad: Gradient tensor

        Returns:
            bool: True if SPD should be applied, False otherwise
        """
        # Past warmup, SPD is disabled entirely.
        if state["step"] >= group["warmup_steps"]:
            return False

        # pre is None in the LoRA (non-full-finetune) case: anchor at zero.
        pre = state["pre"] if state["pre"] is not None else torch.zeros_like(p)
        condition = -torch.sum(grad * (p - pre))
        return condition < 0.0

    @staticmethod
    def _update_torque_aware_momentum(state: Dict[str, Any], scaled_grad: torch.Tensor, eps1: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update momentum using Torque-Aware Momentum during early training.

        Implementation from:
        https://arxiv.org/abs/2412.18790
        https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/optimizer/tam.py

        Args:
            state: Optimizer state for the parameter
            scaled_grad: Scaled gradient tensor
            eps1: Epsilon parameter for numerical stability

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (exp_avg_bar, exp_avg) — the
            bias-style momentum average used by CAME, and the updated
            in-place momentum buffer.
        """
        # Set fixed beta values for early training
        # NOTE(review): beta2 and beta3 are assigned but unused here.
        beta1, beta2, beta3 = 0.9, 0.999, 0.9999
        decay_rate = 0.9

        # Get state tensors
        s, exp_avg = state['s'], state['exp_avg']

        # Calculate correlation between normalized momentum and gradient
        # (elementwise product of the dim-0-normalized tensors).
        corr = normalize(exp_avg, p=2.0, dim=0).mul_(normalize(scaled_grad, p=2.0, dim=0))

        # Update correlation state (EMA of the correlation).
        s.mul_(decay_rate).add_(corr, alpha=1.0 - decay_rate)

        # Calculate torque-aware update: gradient damped where momentum and
        # gradient historically disagree (s near -1), passed where they agree.
        d = ((1.0 + s) / 2.0).add_(eps1).mul_(scaled_grad)

        # Calculate momentum average
        exp_avg_bar = exp_avg * beta1 + scaled_grad * (1 - beta1)

        # Update momentum
        # NOTE(review): add_(d) has no (1 - beta1) factor, unlike the
        # reference TAM — confirm this is intentional.
        exp_avg.mul_(beta1).add_(d)

        return exp_avg_bar, exp_avg

    @staticmethod
    def _update_consistency_momentum(state: Dict[str, Any], group: Dict[str, Any], scaled_grad: torch.Tensor, beta1: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update momentum using consistency-based approach after early training.

        Implementation from:
        Towards Faster Training of Diffusion Models: An Inspiration of A Consistency Phenomenon
        https://arxiv.org/abs/2404.07946

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            scaled_grad: Scaled gradient tensor
            beta1: Beta1 parameter for momentum

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (exp_avg_bar, exp_avg) — the
            fixed-beta1 momentum average and the updated in-place buffer.
        """
        # Calculate time-dependent beta1 (decays with step, floored at 0.4).
        beta1_t = max(beta1 * group['beta1_decay'] ** state["step"], 0.4)
        # NOTE(review): beta1_factor is computed but never used.
        beta1_factor = (1 - beta1) / (1 - beta1_t) if beta1_t < 1 else 1.0

        # Get momentum state
        exp_avg = state['exp_avg']

        # Calculate momentum average (uses the fixed beta1, not beta1_t).
        exp_avg_bar = exp_avg * beta1 + scaled_grad * (1 - beta1)

        # Update momentum with time-dependent beta1
        exp_avg.mul_(beta1_t).add_(scaled_grad, alpha=1 - beta1_t)

        return exp_avg_bar, exp_avg

    @staticmethod
    def _update_post_warmup_lr_mask(state: Dict[str, Any], group: Dict[str, Any]) -> torch.Tensor:
        """Update learning rate mask after warmup phase.

        After warmup the mask is frozen except for tracking external lr
        schedules: if the group's lr drops below its observed maximum, the
        whole mask is scaled down proportionally (floored at 0.1).

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings

        Returns:
            torch.Tensor: Updated learning rate mask
        """
        new_lr = state['lr_mask']

        # Update maximum learning rate if needed
        if group["lr"] > state["lr_max"]:
            state["lr_max"] = group["lr"]

        # Scale learning rate if current lr is less than maximum
        if group["lr"] < state["lr_max"]:
            new_lr = new_lr * max((group["lr"] / state["lr_max"]), 0.1)

        return new_lr

    def _get_group_lr(self, group: Dict[str, Any]) -> float:
        """Get the average learning rate for a parameter group.

        Falls back to the configured base lr for groups whose parameters
        have no recorded 'avg_lr' yet.
        """
        group_lrs = []
        for p in group["params"]:
            state = self.state[p]
            if 'avg_lr' in state:
                group_lrs.append(state['avg_lr'])
        return float(torch.mean(torch.tensor(group_lrs))) if group_lrs else self.config.lr

    def _init_state(self, p: torch.Tensor, group: Optional[Dict[str, Any]] = None) -> None:
        """Initialize optimizer state for a parameter.

        Creates the step counter, the per-element lr mask (seeded at
        config.lr), polarity tracking, and fp32 momentum/variance buffers.
        The SPD anchor 'pre' and ALLoRA row scaling depend on
        group['full_finetune'].
        """
        device = p.device
        shape = p.shape
        state = self.state[p]

        # Basic state initialization
        state.setdefault("lr_max", 1e-6)
        state.setdefault("step", 0)

        # Learning rate mask initialization
        state.setdefault('lr_mask', torch.ones(shape, device=device, dtype=torch.float32) * self.config.lr)
        state.setdefault('avg_lr', float(self.config.lr))
        state.setdefault('last_polarity', torch.zeros(shape, dtype=torch.bool, device=device))

        # Momentum and variance initialization
        state.setdefault("exp_avg", torch.zeros_like(p))
        state.setdefault("s", torch.zeros_like(p))
        state.setdefault("exp_avg_sq", torch.zeros_like(p))
        state.setdefault("exp_avg_res", torch.zeros_like(p))

        # Full finetune initialization: LoRA-style training (full_finetune
        # False) anchors SPD at zero (pre=None) and enables ALLoRA row
        # scaling for 2D weights; full finetune anchors at the initial weights.
        if group is not None and group['full_finetune'] is False:
            state.setdefault("pre", None)
            """
            ==== ALLoRA ====
            ALLoRA: Adaptive Learning Rate Mitigates LoRA Fatal Flaws
            https://arxiv.org/abs/2410.09692
            """
            if len(p.shape) == 2:
                row_norm = p.norm(dim=1, keepdim=True)
                state["row_scaling"] = 1.0 / torch.sqrt(row_norm + 1.0 / (group['eta']**2))
        else:
            state.setdefault("pre", p.clone())

    def _apply_spd_update(self, p: torch.Tensor, update_p: torch.Tensor, state: Dict[str, Any], group: Dict[str, Any]) -> None:
        """Apply SPD (Selective Projection Decay) update to the parameter.

        Takes the plain step p - update_p, then pulls the result back toward
        the anchor 'pre' proportionally to the drift ratio and
        group['weight_decay']. Writes the result into p in place.

        Args:
            p: Parameter tensor to update
            update_p: Update parameter tensor
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
        """
        # pre is None in the LoRA case: project toward zero.
        pre = state["pre"] if state["pre"] is not None else torch.zeros_like(p)
        # Calculate new parameter value
        new_p = p - update_p

        # Apply selective projection decay
        ratio = self._ratio(new_p, p, pre)
        new_p = new_p - group["weight_decay"] * ratio * (new_p - pre)
        p.copy_(new_p)

    def _update_learning_rate_mask(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor) -> torch.Tensor:
        """Update the learning rate mask based on gradient polarity and current state.

        Dispatches to the warmup rule (sign-agreement bumping) or the
        post-warmup rule (schedule tracking only).

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            grad: Gradient tensor

        Returns:
            torch.Tensor: Updated learning rate mask
        """
        if state["step"] < group["warmup_steps"]:
            return self._update_warmup_lr_mask(state, group, grad)
        else:
            return self._update_post_warmup_lr_mask(state, group)

    def _update_warmup_lr_mask(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor) -> torch.Tensor:
        """Update learning rate mask during warmup phase.

        Per element: bump the lr up when the gradient sign matches the
        previous step's sign, down when it flips; then clamp to
        [min_lr, max_lr].

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            grad: Gradient tensor

        Returns:
            torch.Tensor: Updated learning rate mask
        """
        # Update polarity tracking
        last_polarity = state['last_polarity']
        current_polarity = (grad > 0)
        sign_agree = torch.where(last_polarity == current_polarity, 1.0, -1.0)
        state['last_polarity'] = current_polarity

        # Calculate new learning rate
        lr_mask = state['lr_mask']
        new_lr = torch.where(
            sign_agree > 0,
            lr_mask + self.config.lr_bump,
            lr_mask - self.config.lr_bump
        )

        # Handle learning rate maximum: pass external lr increases through
        # to the mask additively.
        if group["lr"] > state["lr_max"]:
            new_lr = new_lr + (group["lr"] - state["lr_max"])
            state["lr_max"] = group["lr"]

        # Clamp learning rate to valid range
        new_lr = torch.clamp(new_lr, min=self.config.min_lr, max=self.config.max_lr)

        # Update state
        state['lr_mask'] = new_lr
        state['avg_lr'] = torch.mean(new_lr).item()

        return new_lr

    def _update_momentum(self, state: Dict[str, Any], group: Dict[str, Any], scaled_grad: torch.Tensor, beta1: float, eps1: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update momentum based on current training phase.

        Uses Torque-Aware Momentum for the first half of warmup, then the
        consistency-based update.

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            scaled_grad: Scaled gradient tensor
            beta1: Beta1 parameter for momentum
            eps1: Epsilon parameter for numerical stability

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (exp_avg_bar, exp_avg)
        """
        if state["step"] < group["warmup_steps"] / 2:
            exp_avg_bar, exp_avg = self._update_torque_aware_momentum(state, scaled_grad, eps1)
        else:
            exp_avg_bar, exp_avg = self._update_consistency_momentum(state, group, scaled_grad, beta1)

        return exp_avg_bar, exp_avg
|
|
|
|
class Automagic_CameAMP(BaseOptimizer):
    """Automagic_CameAMP optimizer implementation.

    Per step, combines: AGR gradient regularization, an ADOPT-style scaled
    gradient, TAM/consistency momentum, the CAME confidence residual, the
    Automagic per-element lr mask, Grams / cautious update shaping,
    orthogonal-gradient projection, ALLoRA row scaling, and SPD.
    """

    def __init__(self, params, **kwargs):
        """Accept OptimizerConfig fields as keyword arguments."""
        config = OptimizerConfig(**kwargs)
        super().__init__(params, config)

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by the closure, or None.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            grads_this_group = []
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue
                grads_this_group.append(p.grad.view(-1))
            # Fix: a group whose parameters all lack gradients would crash
            # torch.cat([]) — skip it instead.
            if not grads_this_group:
                continue
            all_group_grads = torch.cat(grads_this_group)
            # Fix: clamp away zero so the AGR division below cannot produce
            # NaN when every gradient in the group is exactly zero.
            sum_abs_all_group_grads = torch.sum(torch.abs(all_group_grads)).clamp_(min=1e-30)

            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue

                # === state initialization ===
                state = self.state[p]
                if len(state) == 0:
                    self._init_state(p, group)

                if 'step' not in state:
                    state['step'] = 0
                state["step"] += 1
                # At the end of warmup, drop warmup-only buffers to free memory.
                if state["step"] == group["warmup_steps"]:
                    if 's' in state:
                        del state['s']
                    if 'last_polarity' in state:
                        del state['last_polarity']
                    if 'pre' in state and state["pre"] is not None:
                        del state['pre']

                """
                === gradient preprocessing ===
                ==== AGR: Adaptive Gradient Regularization ====
                Adaptive Gradient Regularization: A Faster and Generalizable Optimization Technique for Deep Neural Networks
                https://arxiv.org/pdf/2407.16944
                """
                grad = p.grad
                abs_grad = torch.abs(grad)
                # Down-weight each element by its share of the group's total
                # absolute gradient mass.
                alpha = abs_grad / sum_abs_all_group_grads
                grad = grad * (1 - alpha)

                beta1, beta2, beta3 = group["betas"]
                eps1, eps2, eps3 = group["eps"]

                """
                ADOPT: Modified Adam Can Converge with Any β_2 with the Optimal Rate
                https://arxiv.org/abs/2411.02853
                https://github.com/iShohei220/adopt
                """
                exp_avg_sq = state["exp_avg_sq"]

                # First step only seeds the second moment; no update is applied.
                # NOTE(review): exp_avg_sq is never EMA-updated on later steps,
                # unlike reference ADOPT — confirm this is intentional.
                if state['step'] == 1:
                    exp_avg_sq.addcmul_(grad, grad.conj())
                    continue

                de_nom = exp_avg_sq.sqrt().clamp_(min=1e-6)
                scaled_grad = grad.div(de_nom)
                # ADOPT-style clipping bound that loosens as training progresses.
                clip = state['step'] ** 0.25
                scaled_grad.clamp_(-clip, clip)

                """
                ==== Momentum Update ====
                """
                exp_avg_bar, exp_avg = self._update_momentum(state, group, scaled_grad, beta1, eps1)

                """
                ==== CAME core block (always non-factored) ====
                CAME: Confidence-guided Adaptive Memory Efficient Optimization
                https://arxiv.org/pdf/2411.02853
                https://github.com/yangluo7/CAME
                """
                exp_avg_res = state["exp_avg_res"]
                # Residual between the scaled gradient and the momentum average
                # measures confidence in the update direction.
                res = (scaled_grad - exp_avg_bar).pow(2) + eps2
                exp_avg_res.mul_(beta3).add_(res, alpha=1.0 - beta3)
                update_p = exp_avg.clone().mul_(exp_avg_res.rsqrt())

                """
                ==== Automagic lr mask ====
                https://github.com/ostris/ai-toolkit/blob/main/toolkit/optimizers/automagic.py
                """
                new_lr = self._update_learning_rate_mask(state, group, grad)

                if state["step"] < group["warmup_steps"] / 2:
                    """
                    === Grams ===
                    Grams: Gradient Descent with Adaptive Momentum Scaling
                    https://arxiv.org/abs/2412.17107
                    https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/optimizer/grams.py
                    """
                    # Keep the adaptive magnitude but take the sign from the gradient.
                    update_p.abs_().mul_(grad.sign())
                else:
                    if self.config.cautious:
                        """
                        === Cautious ===
                        Cautious Optimizers: Improving Training with One Line of Code
                        https://arxiv.org/abs/2411.16085
                        https://github.com/kyleliang919/C-Optim
                        """
                        # Zero out components that disagree with the gradient,
                        # renormalizing so the mean update magnitude is preserved.
                        mask = (update_p * grad > 0).to(grad.dtype)
                        mask.div_(mask.mean().clamp_(min=1e-3))
                        update_p = (update_p * mask)

                """
                === Orthogonal gradient ===
                Grokking at the Edge of Numerical Stability
                https://arxiv.org/abs/2501.04697
                https://github.com/LoganBooker/prodigy-plus-schedule-free/tree/dev
                """
                if state["step"] < group["warmup_steps"] / 2:
                    update_p = self.orthograd_(p, update_p)
                # ALLoRA per-row scaling (set in _init_state for 2D LoRA weights).
                if "row_scaling" in state:
                    update_p = update_p * state["row_scaling"]

                update_p = update_p.mul(new_lr)

                """
                === SPD: Selective Projection Decay ===
                Rethinking Weight Decay for Robust Fine-Tuning of Foundation Models
                https://arxiv.org/abs/2411.01713
                https://github.com/GT-RIPL/Selective-Projection-Decay/tree/main
                Mirror, Mirror of the Flow: How Does Regularization Shape Implicit Bias?
                https://arxiv.org/abs/2504.12883
                """
                do_spd = self._should_apply_spd(state, group, p, grad)
                if do_spd:
                    self._apply_spd_update(p, update_p, state, group)
                else:
                    p.add_(-update_p)

        if self.config.verbose:
            print([group["lr"] for group in self.param_groups])
        return loss

    def state_dict(self) -> Dict[str, Any]:
        """Get the optimizer state dictionary, tagged with a version marker."""
        state = super().state_dict()
        state['magic_version'] = 1
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the optimizer state dictionary, warning on version mismatch."""
        if 'magic_version' not in state_dict or state_dict['magic_version'] != 1:
            print('[WARNING] You loaded an unexpected state dict, some dynamic mask parameters may not be properly synchronized!')
        super().load_state_dict(state_dict)
|
|
|
|
class Automagic_CameAMP8bit(BaseOptimizer):
    """8-bit version of Automagic_CameAMP optimizer.

    Keeps the momentum and lr mask blockwise-quantized (bitsandbytes) and
    uses Adafactor-style factored second moments for >=2D parameters to
    reduce optimizer memory.
    """

    def __init__(self, params, **kwargs):
        """Accept OptimizerConfig fields as keyword arguments."""
        config = OptimizerConfig(**kwargs)
        super().__init__(params, config)

    def _init_state(self, p: torch.Tensor, group: Optional[Dict[str, Any]] = None) -> None:
        """Initialize 8-bit optimizer state for a parameter.

        Args:
            p: Parameter tensor
            group: Parameter group containing optimization settings
        """
        device = p.device
        shape = p.shape
        state = self.state[p]
        state.setdefault("step", 0)

        # Initialize quantized learning rate mask
        lr_mask_init = torch.ones(shape, device=device, dtype=torch.float32) * self.config.lr
        q_lr_mask, q_lr_mask_scale = F.quantize_blockwise(lr_mask_init, blocksize=2048)
        state.setdefault('lr_mask_q', q_lr_mask)
        state.setdefault('lr_mask_q_scale', q_lr_mask_scale)
        state.setdefault('avg_lr', float(self.config.lr))
        state.setdefault('last_polarity', torch.zeros(shape, dtype=torch.bool, device=device))

        # Initialize quantized momentum
        exp_avg_fp32 = torch.zeros_like(p)
        q_exp_avg, q_exp_avg_scale = F.quantize_blockwise(exp_avg_fp32, blocksize=2048)
        state.setdefault("exp_avg_q", q_exp_avg)
        state.setdefault("exp_avg_q_scale", q_exp_avg_scale)
        # Fix: the TAM momentum path (step < warmup/2) reads state['s'];
        # without it every early step raised KeyError. Kept in fp32.
        state.setdefault("s", torch.zeros_like(p))

        # Initialize variance tracking: factored (Adafactor-style row/col
        # averages) for >=2D params, dense otherwise.
        if len(shape) >= 2:
            state["exp_avg_sq_row"] = torch.zeros(shape[:-1], device=device)
            state["exp_avg_sq_col"] = torch.zeros(shape[:-2]+shape[-1:], device=device)
            state["exp_avg_res_row"] = torch.zeros(shape[:-1], device=device)
            state["exp_avg_res_col"] = torch.zeros(shape[:-2]+shape[-1:], device=device)
        else:
            state["exp_avg_sq"] = torch.zeros_like(p)
            state["exp_avg_res"] = torch.zeros_like(p)

        # Full finetune initialization: anchor SPD at the initial weights,
        # otherwise at zero (pre=None).
        if group is not None and group.get('full_finetune', False):
            state.setdefault("pre", p.clone())
        else:
            state.setdefault("pre", None)

    @staticmethod
    def _update_post_warmup_lr_mask(state: Dict[str, Any], group: Dict[str, Any]) -> torch.Tensor:
        """Update learning rate mask after warmup phase.

        Dequantizes the stored mask; if the group's lr has decayed below
        1e-6 (the default base lr), scales the mask down proportionally.

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings

        Returns:
            torch.Tensor: Updated learning rate mask
        """
        new_lr = F.dequantize_blockwise(state['lr_mask_q'], state['lr_mask_q_scale'], blocksize=2048)

        if group["lr"] < 1e-6:
            new_lr = new_lr * (group["lr"] / 1e-6)

        return new_lr

    def _update_warmup_lr_mask(self, state: Dict[str, Any], group: Dict[str, Any], grad: torch.Tensor) -> torch.Tensor:
        """Update learning rate mask during warmup phase.

        Same sign-agreement bumping as the fp32 variant, but the mask is
        dequantized before and requantized after the update.

        Args:
            state: Optimizer state for the parameter
            group: Parameter group containing optimization settings
            grad: Gradient tensor

        Returns:
            torch.Tensor: Updated learning rate mask
        """
        # Update polarity tracking
        last_polarity = state['last_polarity']
        current_polarity = (grad > 0)
        sign_agree = torch.where(last_polarity == current_polarity, 1, -1)
        state['last_polarity'] = current_polarity

        # Calculate new learning rate
        lr_mask = F.dequantize_blockwise(state['lr_mask_q'], state['lr_mask_q_scale'], blocksize=2048)
        new_lr = torch.where(
            sign_agree > 0,
            lr_mask + self.config.lr_bump,
            lr_mask - self.config.lr_bump
        )

        # Clamp learning rate to valid range
        new_lr = torch.clamp(new_lr, min=self.config.min_lr, max=self.config.max_lr)

        # Update quantized learning rate mask
        q_lr_mask, q_lr_mask_scale = F.quantize_blockwise(new_lr, blocksize=2048)
        state['lr_mask_q'] = q_lr_mask
        state['lr_mask_q_scale'] = q_lr_mask_scale
        state['avg_lr'] = torch.mean(new_lr).item()

        return new_lr

    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        """Perform a single optimization step with 8-bit quantization.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by the closure, or None.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue

                grad = p.grad
                state = self.state[p]
                factored = len(p.shape) >= 2

                # Initialize state if needed
                if len(state) == 0:
                    self._init_state(p, group)

                if 'step' not in state:
                    state['step'] = 0
                state["step"] += 1

                beta1, beta2, beta3 = group["betas"]
                # Fix: eps is a 3-tuple; the original two-name unpacking
                # raised ValueError on every step.
                eps1, eps2, eps3 = group["eps"]

                # Adafactor/RMS core
                update_p = grad.pow(2) + eps1
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]
                    exp_avg_sq_row.mul_(beta2).add_(update_p.mean(dim=-1), alpha=1 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(update_p.mean(dim=-2), alpha=1 - beta2)
                    update_p = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update_p.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2).add_(update_p, alpha=1 - beta2)
                    update_p = grad.clone().mul_(exp_avg_sq.rsqrt())
                update_p.div_((self._rms(update_p) / group["clip_threshold"]).clamp_(min=1.0))

                # Update momentum.
                # Fix: the shared momentum helpers operate on an fp32
                # state['exp_avg'] buffer which the 8-bit state never held
                # (KeyError) — dequantize before the call, requantize after.
                state['exp_avg'] = F.dequantize_blockwise(
                    state['exp_avg_q'], state['exp_avg_q_scale'], blocksize=2048)
                exp_avg_bar, exp_avg = self._update_momentum(state, group, update_p, beta1, eps1)
                q_exp_avg, q_exp_avg_scale = F.quantize_blockwise(exp_avg, blocksize=2048)
                state['exp_avg_q'] = q_exp_avg
                state['exp_avg_q_scale'] = q_exp_avg_scale
                del state['exp_avg']

                # Update learning rate mask
                new_lr = self._update_learning_rate_mask(state, group, grad)

                # Apply update.
                # Fix: the original referenced an undefined name `update`
                # (NameError); the intended operand is update_p.
                update_p = update_p.mul(new_lr)

                # SPD
                do_spd = self._should_apply_spd(state, group, p, grad)
                if do_spd:
                    self._apply_spd_update(p, update_p, state, group)
                else:
                    p.add_(-update_p)

        if self.config.verbose:
            print([group["lr"] for group in self.param_groups])
        return loss

    def state_dict(self) -> Dict[str, Any]:
        """Get the 8-bit optimizer state dictionary.

        Stores the quantized tensors (plus their scales) and drops any
        transient dequantized copies.
        """
        orig_sd = super().state_dict()
        new_state = {}
        for k, v in orig_sd['state'].items():
            # Don't store unquantized tensors
            save_state = {kk: vv for kk, vv in v.items()
                          if kk not in ('lr_mask', 'exp_avg_q', 'exp_avg_q_scale', 'lr_mask_q', 'lr_mask_q_scale')}
            # Save quantized tensors
            if 'exp_avg_q' in v and 'exp_avg_q_scale' in v:
                save_state['exp_avg_q'] = v['exp_avg_q']
                save_state['exp_avg_q_scale'] = v['exp_avg_q_scale']
            if 'lr_mask_q' in v and 'lr_mask_q_scale' in v:
                save_state['lr_mask_q'] = v['lr_mask_q']
                save_state['lr_mask_q_scale'] = v['lr_mask_q_scale']
            new_state[k] = save_state
        orig_sd['state'] = new_state
        orig_sd['magic8_version'] = 1
        return orig_sd

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the 8-bit optimizer state dictionary.

        Loads the non-quantized state through the base class, then restores
        the quantized tensors by parameter position.
        """
        if 'magic8_version' not in state_dict or state_dict['magic8_version'] != 1:
            print('[WARNING] You loaded an unexpected state dict, some 8-bit parameters may not be synchronized!')
        basic_sd = {'state': {}, 'param_groups': state_dict['param_groups']}
        for k, v in state_dict['state'].items():
            basic_sd['state'][k] = {kk: vv for kk, vv in v.items()
                                    if kk not in ('exp_avg_q', 'exp_avg_q_scale', 'lr_mask_q', 'lr_mask_q_scale')}
        super().load_state_dict(basic_sd)
        # Restore quantized tensors
        param_map = [p for g in self.param_groups for p in g['params']]
        for idx, p in enumerate(param_map):
            # Fix: torch keys optimizer state by int index; the original
            # str(idx) lookup never matched, so the quantized tensors were
            # silently never restored. Accept both key types.
            src = state_dict['state'].get(idx)
            if src is None:
                src = state_dict['state'].get(str(idx))
            if src is None:
                continue
            st = self.state[p]
            if 'exp_avg_q' in src and 'exp_avg_q_scale' in src:
                st['exp_avg_q'] = src['exp_avg_q']
                st['exp_avg_q_scale'] = src['exp_avg_q_scale']
            if 'lr_mask_q' in src and 'lr_mask_q_scale' in src:
                st['lr_mask_q'] = src['lr_mask_q']
                st['lr_mask_q_scale'] = src['lr_mask_q_scale']
|