diff --git a/library/galore_optimizer.py b/library/galore_optimizer.py
new file mode 100644
index 00000000..0f98e7cb
--- /dev/null
+++ b/library/galore_optimizer.py
@@ -0,0 +1,798 @@
+# copied from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
+# original license is Apache License 2.0
+import ast
+import math
+import warnings
+from typing import Callable, Dict, Iterable, Tuple
+
+import torch
+from torch import nn
+from torch.optim import Optimizer
+
+from transformers.utils.versions import require_version
+
+from library import train_util
+
+from .utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GaLoreProjector:
+    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
+        self.rank = rank
+        self.verbose = verbose
+        self.update_proj_gap = update_proj_gap
+        self.scale = scale
+        self.ortho_matrix = None
+        self.proj_type = proj_type
+
+    def project(self, full_rank_grad, iter):
+
+        if self.proj_type == "std":
+            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
+                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+            else:
+                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
+                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+        elif self.proj_type == "reverse_std":
+            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
+                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
+                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+            else:
+                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
+                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+        elif self.proj_type == "right":
+            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
+            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
+        elif self.proj_type == "left":
+            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
+            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
+        elif self.proj_type == "full":
+            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
+                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
+            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
+
+        return low_rank_grad
+
+    def project_back(self, low_rank_grad):
+
+        if self.proj_type == "std":
+            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
+                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+            else:
+                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+        elif self.proj_type == "reverse_std":
+            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
+                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+            else:
+                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+        elif self.proj_type == "right":
+            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
+        elif self.proj_type == "left":
+            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
+        elif self.proj_type == "full":
+            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
+
+        return full_rank_grad * self.scale
+
+    # SVD decomposition
+    def get_orthogonal_matrix(self, weights, rank, type):
+        module_params = weights
+
+        if module_params.data.dtype != torch.float:
+            float_data = False
+            original_type = module_params.data.dtype
+            original_device = module_params.data.device
+            matrix = module_params.data.float()
+        else:
+            float_data = True
+            matrix = module_params.data
+
+        U, s, Vh = torch.linalg.svd(matrix)
+
+        # always make the smaller factor the orthogonal matrix
+        if type == "right":
+            A = U[:, :rank] @ torch.diag(s[:rank])
+            B = Vh[:rank, :]
+
+            if not float_data:
+                B = B.to(original_device).type(original_type)
+            return B
+        elif type == "left":
+            A = U[:, :rank]
+            B = torch.diag(s[:rank]) @ Vh[:rank, :]
+            if not float_data:
+                A = A.to(original_device).type(original_type)
+            return A
+        elif type == "full":
+            A = U[:, :rank]
+            B = Vh[:rank, :]
+            if not float_data:
+                A = A.to(original_device).type(original_type)
+                B = B.to(original_device).type(original_type)
+            return [A, B]
+        else:
+            raise ValueError("type should be left, right or full")
+
+
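+# The optimizers below enable GaLore per parameter group: a group that carries a "rank" key
+# (together with "update_proj_gap", "scale" and "proj_type") has its gradient projected onto the
+# top-`rank` singular directions by GaLoreProjector before the optimizer statistics are formed,
+# so the Adam moments are allocated at the low-rank shape; project_back then restores the
+# full-rank update and applies `scale`. Usage sketch (illustrative values only):
+#   param_groups = [
+#       {"params": regular_params},
+#       {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 1.0, "proj_type": "std"},
+#   ]
+#   optimizer = GaLoreAdamW(param_groups, lr=1e-5)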
+class GaLoreAdamW(Optimizer):
+    """
+    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
+    Regularization](https://arxiv.org/abs/1711.05101).
+
+    Parameters:
+        params (`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries defining parameter groups.
+        lr (`float`, *optional*, defaults to 0.001):
+            The learning rate to use.
+        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
+            Adam's betas parameters (b1, b2).
+        eps (`float`, *optional*, defaults to 1e-06):
+            Adam's epsilon for numerical stability.
+        weight_decay (`float`, *optional*, defaults to 0.0):
+            Decoupled weight decay to apply.
+        correct_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
+        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
+            A flag used to disable the deprecation warning (set to `True` to disable the warning).
+    """
+
+    def __init__(
+        self,
+        params: Iterable[nn.parameter.Parameter],
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-6,
+        weight_decay: float = 0.0,
+        correct_bias: bool = True,
+        no_deprecation_warning: bool = False,
+    ):
+        if not no_deprecation_warning:
+            warnings.warn(
+                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
+                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
+                " warning",
+                FutureWarning,
+            )
+        require_version("torch>=1.5.0")  # add_ with alpha
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
+        if not 0.0 <= eps:
+            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
+        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure: Callable = None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
+
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    grad = state["projector"].project(grad, state["step"])
+
+                # State initialization
+                if "exp_avg" not in state:
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(grad)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                # Decay the first and second moment running average coefficient
+                # In-place operations to update the averages at the same time
+                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+                denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+                step_size = group["lr"]
+                if group["correct_bias"]:  # No bias correction for Bert
+                    bias_correction1 = 1.0 - beta1 ** state["step"]
+                    bias_correction2 = 1.0 - beta2 ** state["step"]
+                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+                # compute norm gradient
+                norm_grad = exp_avg / denom
+
+                # GaLore Projection Back
+                if "rank" in group:
+                    norm_grad = state["projector"].project_back(norm_grad)
+
+                p.add_(norm_grad, alpha=-step_size)
+
+                # Just adding the square of the weights to the loss function is *not*
+                # the correct way of using L2 regularization/weight decay with Adam,
+                # since that will interact with the m and v parameters in strange ways.
+                #
+                # Instead we want to decay the weights in a manner that doesn't interact
+                # with the m/v parameters. This is equivalent to adding the square
+                # of the weights to the loss with plain (non-momentum) SGD.
+                # Add weight decay at the end (fixed version)
+                if group["weight_decay"] > 0.0:
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+
+        return loss
+
+
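+# GaLoreAdafactor is the Transformers/fairseq Adafactor with the same per-group GaLore hooks:
+# the gradient is projected before the factored second-moment statistics are built (so the
+# row/column accumulators follow the low-rank shape) and the update is projected back to full
+# rank before it is applied to the parameter.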
+class GaLoreAdafactor(Optimizer):
+    """
+    AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam. Original fairseq code:
+    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235. Note that
+    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
+    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
+    `relative_step=False`.
+
+    Arguments:
+        params (`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries defining parameter groups.
+        lr (`float`, *optional*):
+            The external learning rate.
+        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
+            Regularization constants for square gradient and parameter scale respectively
+        clip_threshold (`float`, *optional*, defaults to 1.0):
+            Threshold of root mean square of final gradient update
+        decay_rate (`float`, *optional*, defaults to -0.8):
+            Coefficient used to compute running averages of square
+        beta1 (`float`, *optional*):
+            Coefficient used for computing running averages of gradient
+        weight_decay (`float`, *optional*, defaults to 0.0):
+            Weight decay (L2 penalty)
+        scale_parameter (`bool`, *optional*, defaults to `True`):
+            If True, learning rate is scaled by root mean square
+        relative_step (`bool`, *optional*, defaults to `True`):
+            If True, time-dependent learning rate is computed instead of external learning rate
+        warmup_init (`bool`, *optional*, defaults to `False`):
+            Time-dependent learning rate computation depends on whether warm-up initialization is being used
+
+    This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
+
+    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
+
+    - Training without LR warmup or clip_threshold is not recommended.
+
+       - use scheduled LR warm-up to fixed LR
+       - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
+    - Disable relative updates
+    - Use scale_parameter=False
+    - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
+
+    Example:
+
+    ```python
+    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
+    ```
+
+    Others reported the following combination to work well:
+
+    ```python
+    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+    ```
+
+    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
+    scheduler as follows:
+
+    ```python
+    from transformers.optimization import Adafactor, AdafactorSchedule
+
+    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+    lr_scheduler = AdafactorSchedule(optimizer)
+    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
+    ```
+
+    Usage:
+
+    ```python
+    # replace AdamW with Adafactor
+    optimizer = Adafactor(
+        model.parameters(),
+        lr=1e-3,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        beta1=None,
+        weight_decay=0.0,
+        relative_step=False,
+        scale_parameter=False,
+        warmup_init=False,
+    )
+    ```"""
+
+    # make the defaults the same as the Trainer's
+    def __init__(
+        self,
+        params,
+        lr=None,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        beta1=None,
+        weight_decay=0.0,
+        scale_parameter=False,
+        relative_step=False,
+        warmup_init=False,
+    ):
+        # scale_parameter=True,
+        # relative_step=True,
+
+        require_version("torch>=1.5.0")  # add_ with alpha
+        if lr is not None and relative_step:
+            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
+        if warmup_init and not relative_step:
+            raise ValueError("`warmup_init=True` requires `relative_step=True`")
+
+        defaults = {
+            "lr": lr,
+            "eps": eps,
+            "clip_threshold": clip_threshold,
+            "decay_rate": decay_rate,
+            "beta1": beta1,
+            "weight_decay": weight_decay,
+            "scale_parameter": scale_parameter,
+            "relative_step": relative_step,
+            "warmup_init": warmup_init,
+        }
+        super().__init__(params, defaults)
+
+    @staticmethod
+    def _get_lr(param_group, param_state):
+        rel_step_sz = param_group["lr"]
+        if param_group["relative_step"]:
+            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+        param_scale = 1.0
+        if param_group["scale_parameter"]:
+            param_scale = max(param_group["eps"][1], param_state["RMS"])
+        return param_scale * rel_step_sz
+
+    @staticmethod
+    def _get_options(param_group, param_shape):
+        factored = len(param_shape) >= 2
+        use_first_moment = param_group["beta1"] is not None
+        return factored, use_first_moment
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    @staticmethod
+    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+        # copied from fairseq's adafactor implementation (via transformers):
+        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
+                if grad.is_sparse:
+                    raise RuntimeError("Adafactor does not support sparse gradients.")
+
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    grad = state["projector"].project(grad, state["step"])
+
+                grad_shape = grad.shape
+
+                factored, use_first_moment = self._get_options(group, grad_shape)
+                # State Initialization
+                if "RMS" not in state:
+                    state["step"] = 0
+
+                    if use_first_moment:
+                        # Exponential moving average of gradient values
+                        state["exp_avg"] = torch.zeros_like(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                    else:
+                        state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                    state["RMS"] = 0
+                else:
+                    if use_first_moment:
+                        state["exp_avg"] = state["exp_avg"].to(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    else:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+                p_data_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_data_fp32 = p_data_fp32.float()
+
+                state["step"] += 1
+                state["RMS"] = self._rms(p_data_fp32)
+                lr = self._get_lr(group, state)
+
+                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                update = (grad**2) + group["eps"][0]
+                if factored:
+                    exp_avg_sq_row = state["exp_avg_sq_row"]
+                    exp_avg_sq_col = state["exp_avg_sq_col"]
+
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+                    # Approximation of exponential moving average of square of gradient
+                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                    update.mul_(grad)
+                else:
+                    exp_avg_sq = state["exp_avg_sq"]
+
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    update = exp_avg_sq.rsqrt().mul_(grad)
+
+                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+                update.mul_(lr)
+
+                if use_first_moment:
+                    exp_avg = state["exp_avg"]
+                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+                    update = exp_avg
+
+                # GaLore Projection Back
+                if "rank" in group:
+                    update = state["projector"].project_back(update)
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+                p_data_fp32.add_(-update)
+
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_data_fp32)
+
+        return loss
+
+
+try:
+    from bitsandbytes.optim.optimizer import Optimizer2State
+except ImportError:
+    # define a dummy Optimizer2State class
+    class Optimizer2State(Optimizer):
+        def __init__(self, *args, **kwargs):
+            raise ImportError("Please install bitsandbytes to use this optimizer")
+
+        def step(self, *args, **kwargs):
+            pass
+
+        def prefetch_state(self, *args, **kwargs):
+            pass
+
+        def init_state(self, *args, **kwargs):
+            pass
+
+        def update_step(self, *args, **kwargs):
+            pass
+
+        def check_overrides(self, *args, **kwargs):
+            pass
+
+        def to_gpu(self, *args, **kwargs):
+            pass
+
+        def to_cpu(self, *args, **kwargs):
+            pass
+
+
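+# GaLoreAdamW8bit wraps bitsandbytes' 8-bit Adam (Optimizer2State). In step() below, p.data and
+# p.grad are temporarily replaced by the projected low-rank gradient so that the quantized
+# optimizer state is allocated at the low-rank shape; the projected-back update is then added to
+# the saved full-rank weights, and weight decay is deferred so that it is applied to the
+# full-rank parameter rather than inside the low-rank update.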
+class GaLoreAdamW8bit(Optimizer2State):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=32,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
+        super().__init__(
+            "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        overflows = []
+
+        if not self.initialized:
+            self.check_overrides()
+            self.to_gpu()  # needed for fairseq pure fp16 training
+            self.initialized = True
+
+        # if self.is_paged: self.page_mng.prefetch_all()
+        for gindex, group in enumerate(self.param_groups):
+            for pindex, p in enumerate(group["params"]):
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    if "weight_decay" in group and group["weight_decay"] > 0:
+                        # ensure that the weight decay is not applied to the norm grad
+                        group["weight_decay_saved"] = group["weight_decay"]
+                        group["weight_decay"] = 0
+
+                    grad = state["projector"].project(p.grad, state["step"])
+
+                    # suboptimal implementation
+                    p.saved_data = p.data.clone()
+                    p.data = grad.clone().to(p.data.dtype).to(p.data.device)
+                    p.data.zero_()
+                    p.grad = grad
+
+                if "state1" not in state:
+                    self.init_state(group, p, gindex, pindex)
+
+                self.prefetch_state(p)
+                self.update_step(group, p, gindex, pindex)
+                torch.cuda.synchronize()
+
+                # GaLore Projection Back
+                if "rank" in group:
+                    p.data = p.saved_data.add_(state["projector"].project_back(p.data))
+
+                    # apply weight decay
+                    if "weight_decay_saved" in group:
+                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
+                        group["weight_decay"] = group["weight_decay_saved"]
+                        del group["weight_decay_saved"]
+
+        if self.is_paged:
+            # all paged operations are asynchronous; we need
+            # to sync to make sure all tensors are in the right state
+            torch.cuda.synchronize()
+
+        return loss
+
+
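+# get_optimizer below wires GaLore into the training scripts: weights of nn.Linear modules whose
+# names contain an entry of target_modules_list are collected into a GaLore parameter group, and
+# all remaining trainable parameters go into a regular group. rank / update_proj_gap / galore_scale /
+# proj_type are read from --optimizer_args. Rough command-line sketch (illustrative values; the
+# exact argument syntax follows the training script's optimizer_args handling):
+#   --optimizer_type galore_adamw8bit --optimizer_args rank=128 update_proj_gap=50 galore_scale=1.0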
"mlp" or "attn" + # target_modules_list = ["attn", "mlp"] + target_modules_list = ["attn", "mlp", "ff"] # for SDXL U-Net + + param_groups = [] + param_lr = {} + for model, params in zip(training_models, trainable_params): + logger.info(f"model: {model.__class__.__name__}") + galore_params = [] + group_lr = params.get("lr", lr) + + for module_name, module in model.named_modules(): + if not isinstance(module, nn.Linear): + continue + + if not any(target_key in module_name for target_key in target_modules_list): + continue + + logger.info("enable GaLore for weights in module: " + module_name) + galore_params.append(module.weight) + + id_galore_params = [id(p) for p in galore_params] + # make parameters without "rank" to another group + regular_params = [p for p in params["params"] if id(p) not in id_galore_params] + + # then call galore_adamw + param_groups.append({"params": regular_params, "lr": group_lr}) + + param_groups.append( + { + "params": galore_params, + "rank": rank, + "update_proj_gap": update_proj_gap, + "scale": galore_scale, + "proj_type": proj_type, + "lr": group_lr, + } + ) + + # record lr + for p in regular_params + galore_params: + param_lr[id(p)] = group_lr + + # select optimizer + scheduler = None + if optimizer_type == "galore_adamw": + optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs) + elif optimizer_type == "galore_adafactor": + beta1 = None if optimizer_kwargs.get("beta1", 0.0) == 0.0 else optimizer_kwargs.pop("beta1") + optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs) + elif optimizer_type == "galore_adamw8bit": + optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs) + elif optimizer_type == "galore_adamw8bit_per_layer": + # TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap + optimizer_dict = {} + all_params = [] + for params in trainable_params: + all_params.extend(params["params"]) + + for p in all_params: + if p.requires_grad: + if id(p) in id_galore_params: + optimizer_dict[p] = GaLoreAdamW8bit( + [ + { + "params": [p], + "rank": rank, + "update_proj_gap": update_proj_gap * 2, + "scale": galore_scale, + "proj_type": proj_type, + } + ], + lr=param_lr[id(p)], + weight_decay=weight_decay, + ) + else: + import bitsandbytes as bnb + + optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay) + + # get scheduler dict + # scheduler needs accelerate.prepare? 
+        # get scheduler dict
+        # does the scheduler need accelerator.prepare?
+        scheduler_dict = {}
+        for p in all_params:
+            if p.requires_grad:
+                scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)
+
+        def optimizer_hook(p):
+            if p.grad is None:
+                return
+            optimizer_dict[p].step()
+            optimizer_dict[p].zero_grad()
+            scheduler_dict[p].step()
+
+        # Register the hook onto every parameter
+        for p in all_params:
+            if p.requires_grad:
+                p.register_post_accumulate_grad_hook(optimizer_hook)
+
+        # make dummy scheduler and optimizer
+        class DummyScheduler:
+            def __init__(self, optimizer):
+                self.optimizer = optimizer
+
+            def step(self):
+                pass
+
+        class DummyOptimizer:
+            def __init__(self, optimizer_dict):
+                self.optimizer_dict = optimizer_dict
+
+            def step(self):
+                pass
+
+            def zero_grad(self, set_to_none=False):
+                pass
+
+        scheduler = DummyScheduler(optimizer_dict[all_params[0]])
+        optimizer = DummyOptimizer(optimizer_dict)
+
+    else:
+        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
+
+    return optimizer, scheduler
diff --git a/library/train_util.py b/library/train_util.py
index b71e4edc..2c78e4f3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3671,7 +3671,7 @@ def get_optimizer(args, trainable_params):
         optimizer_class = lion_pytorch.Lion
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
-    elif optimizer_type.endswith("8bit".lower()):
+    elif optimizer_type.endswith("8bit".lower()) and not optimizer_type.startswith("GaLore".lower()):
         try:
             import bitsandbytes as bnb
         except ImportError:
@@ -3880,6 +3880,11 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type.startswith("GaLore".lower()):
+        logger.info(f"use GaLore optimizer | {optimizer_kwargs}")
+        optimizer = "galore"
+        return None, None, optimizer
+
     if optimizer is None:
         # 任意のoptimizerを使う
         optimizer_type = args.optimizer_type  # lowerでないやつ(微妙)
diff --git a/sdxl_train.py b/sdxl_train.py
index e0df263d..50e38921 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -11,6 +11,7 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 
 from accelerate.utils import set_seed
@@ -378,7 +379,17 @@ def train(args):
         train_dataset_group.set_max_train_steps(args.max_train_steps)
 
     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    lr_scheduler = None
+    if optimizer == "galore":
+        from library import galore_optimizer
+
+        # if not layer-wise, the returned lr_scheduler is None (it is created below); if layer-wise, it is a dummy scheduler
+        optimizer, lr_scheduler = galore_optimizer.get_optimizer(
+            args, args.optimizer_type, params_to_optimize, training_models, accelerator.num_processes
+        )
+
+    if lr_scheduler is None:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16: