# copy from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
# original license is Apache License 2.0
import ast
import math
import warnings
from typing import Callable, Dict, Iterable, Tuple

import torch
from torch import nn
from torch.optim import Optimizer

from transformers.utils.versions import require_version

from library import train_util

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class GaLoreProjector:
    """Projects full-rank gradients into a low-rank subspace and back.

    The orthogonal projection matrix is obtained from a truncated SVD of the
    current gradient and is refreshed every ``update_proj_gap`` steps.

    Args:
        rank: Target rank of the low-rank projection.
        verbose: Unused here; kept for upstream compatibility.
        update_proj_gap: Number of steps between SVD refreshes.
        scale: Multiplier applied to the projected-back gradient.
        proj_type: One of "std", "reverse_std", "right", "left", "full".
    """

    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None  # lazily computed on first project() call
        self.proj_type = proj_type

    def project(self, full_rank_grad, iter):
        """Project ``full_rank_grad`` into the low-rank subspace.

        ``iter`` is the optimizer step counter; the projection matrix is
        recomputed when it is a multiple of ``update_proj_gap``.
        """
        if self.proj_type == "std":
            # Project along the smaller dimension so the ortho matrix is small.
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "reverse_std":
            # Same as "std" but with the side of the projection swapped.
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "right":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == "left":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == "full":
            if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
        else:
            # fix: previously an unknown proj_type fell through and raised
            # UnboundLocalError on low_rank_grad; raise a meaningful error instead.
            raise ValueError("proj_type should be std, reverse_std, right, left or full")

        return low_rank_grad

    def project_back(self, low_rank_grad):
        """Map a low-rank update back to the full-rank parameter space, scaled by ``self.scale``."""
        if self.proj_type == "std":
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == "reverse_std":
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == "right":
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == "left":
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == "full":
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
        else:
            # fix: consistent with project(); avoid UnboundLocalError for bad proj_type
            raise ValueError("proj_type should be std, reverse_std, right, left or full")

        return full_rank_grad * self.scale

    # svd decomposition
    def get_orthogonal_matrix(self, weights, rank, type):
        """Return rank-``rank`` orthogonal factor(s) of ``weights`` via SVD.

        type="right" returns the top-``rank`` right singular vectors (Vh rows),
        type="left" the top-``rank`` left singular vectors (U columns), and
        type="full" returns ``[U_r, Vh_r]``. Non-float inputs are computed in
        float32 and cast back to the original dtype/device.
        """
        module_params = weights

        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        U, s, Vh = torch.linalg.svd(matrix)

        # make the smaller matrix always to be orthogonal matrix
        if type == "right":
            # fix: dropped the unused A = U[:, :rank] @ diag(s[:rank]) computation
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == "left":
            # fix: dropped the unused B = diag(s[:rank]) @ Vh[:rank, :] computation
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == "full":
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError("type should be left, right or full")


class GaLoreAdamW(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101), with optional GaLore low-rank projection
    for parameter groups that define a "rank" key.

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
        no_deprecation_warning: bool = False,
    ):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
                " implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
                " warning",
                FutureWarning,
            )
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # GaLore Projection: groups with a "rank" key get their gradient
                # projected into a low-rank subspace before the Adam update.
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )
                    grad = state["projector"].project(grad, state["step"])

                # State initialization (moments live in the projected space when GaLore is active)
                if "exp_avg" not in state:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # compute norm gradient
                norm_grad = exp_avg / denom

                # GaLore Projection Back: map the low-rank update to full-rank space
                if "rank" in group:
                    norm_grad = state["projector"].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                if group["weight_decay"] > 0.0:
                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


class GaLoreAdafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235

    Note that this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step`
    and `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False`
    and `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

    - Training without LR warmup or clip_threshold is not recommended.

       - use scheduled LR warm-up to fixed LR
       - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
    - Disable relative updates
    - Use scale_parameter=False
    - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    # make default to be the same as trainer
    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
    ):
        # scale_parameter=True,
        # relative_step=True,
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
        if warmup_init and not relative_step:
            raise ValueError("`warmup_init=True` requires `relative_step=True`")

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "beta1": beta1,
            "weight_decay": weight_decay,
            "scale_parameter": scale_parameter,
            "relative_step": relative_step,
            "warmup_init": warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        # Effective LR: relative-step schedule (optionally warmup-scaled) and/or RMS parameter scaling.
        rel_step_sz = param_group["lr"]
        if param_group["relative_step"]:
            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            param_scale = max(param_group["eps"][1], param_state["RMS"])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        # Factored second moment only for >=2D params; first moment only when beta1 is set.
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError("Adafactor does not support sparse gradients.")

                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # GaLore Projection
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )
                    grad = state["projector"].project(grad, state["step"])

                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if "RMS" not in state:
                    state["step"] = 0
                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(grad)
                    if factored:
                        # Row/column factored second moment (sublinear memory)
                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad)
                    state["RMS"] = 0
                else:
                    # Keep state on the same device/dtype as the (possibly projected) gradient
                    if use_first_moment:
                        state["exp_avg"] = state["exp_avg"].to(grad)
                    if factored:
                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                    else:
                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = (grad**2) + group["eps"][0]
                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
                    update = exp_avg

                # GaLore Projection Back
                if "rank" in group:
                    update = state["projector"].project_back(update)

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in {torch.float16, torch.bfloat16}:
                    p.copy_(p_data_fp32)

        return loss


try:
    from bitsandbytes.optim.optimizer import Optimizer2State
except ImportError:
    # define a dummy Optimizer2State class so module import succeeds without bitsandbytes
    class Optimizer2State(Optimizer):
        def __init__(self, *args, **kwargs):
            raise ImportError("Please install bitsandbytes to use this optimizer")

        def step(self, *args, **kwargs):
            pass

        def prefetch_state(self, *args, **kwargs):
            pass

        def init_state(self, *args, **kwargs):
            pass

        def update_step(self, *args, **kwargs):
            pass

        def check_overrides(self, *args, **kwargs):
            pass

        def to_gpu(self, *args, **kwargs):
            pass

        def to_cpu(self, *args, **kwargs):
            pass


class GaLoreAdamW8bit(Optimizer2State):
    """8-bit AdamW (bitsandbytes Optimizer2State) with optional GaLore projection
    for parameter groups carrying a "rank" key."""

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
    ):
        super().__init__(
            "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        overflows = []

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        # if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]

                if "step" not in state:
                    state["step"] = 0

                # GaLore Projection
                if "rank" in group:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )

                    if "weight_decay" in group and group["weight_decay"] > 0:
                        # ensure that the weight decay is not applied to the norm grad
                        group["weight_decay_saved"] = group["weight_decay"]
                        group["weight_decay"] = 0

                    grad = state["projector"].project(p.grad, state["step"])

                    # suboptimal implementation: swap p.data for a zero tensor shaped like
                    # the projected gradient so the 8-bit update runs in low-rank space.
                    p.saved_data = p.data.clone()
                    p.data = grad.clone().to(p.data.dtype).to(p.data.device)
                    p.data.zero_()
                    p.grad = grad

                if "state1" not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()

                # GaLore Projection Back
                if "rank" in group:
                    p.data = p.saved_data.add_(state["projector"].project_back(p.data))

                    # apply weight decay
                    if "weight_decay_saved" in group:
                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
                        group["weight_decay"] = group["weight_decay_saved"]
                        del group["weight_decay_saved"]

        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss


def get_optimizer(args, optimizer_type, trainable_params, training_models, num_processes):
    """Build a GaLore optimizer (and optionally a scheduler) for the given models.

    GaLore is enabled for ``nn.Linear`` weights whose module name contains one of
    the target keys; all other trainable params go into a regular group.

    Returns:
        (optimizer, scheduler) — scheduler is None except for the per-layer 8bit mode.
    """
    # trainable_params is list of dict, each dict contains "params" and "lr"
    # list may contain multiple dicts: [unet] or [unet, te1] or [unet, te1, te2]
    # block lr is not supported
    assert len(trainable_params) == len(training_models), "block lr is not supported"

    lr = args.learning_rate

    optimizer_kwargs = {}
    if args.optimizer_args is not None and len(args.optimizer_args) > 0:
        for arg in args.optimizer_args:
            key, value = arg.split("=")
            value = ast.literal_eval(value)
            optimizer_kwargs[key] = value

    rank = optimizer_kwargs.pop("rank", 128)
    update_proj_gap = optimizer_kwargs.pop("update_proj_gap", 50)
    galore_scale = optimizer_kwargs.pop("galore_scale", 1.0)
    proj_type = optimizer_kwargs.pop("proj_type", "std")
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)  # do not pop, as it is used in the optimizer

    # make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
    # target_modules_list = ["attn", "mlp"]
    target_modules_list = ["attn", "mlp", "ff"]  # for SDXL U-Net

    param_groups = []
    param_lr = {}
    for model, params in zip(training_models, trainable_params):
        logger.info(f"model: {model.__class__.__name__}")
        galore_params = []
        group_lr = params.get("lr", lr)

        for module_name, module in model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            if not any(target_key in module_name for target_key in target_modules_list):
                continue
            logger.info("enable GaLore for weights in module: " + module_name)
            galore_params.append(module.weight)

        # NOTE(review): id_galore_params is rebound each loop iteration, so the
        # per-layer branch below sees only the LAST model's GaLore params — confirm
        # whether multi-model per-layer mode is intended.
        id_galore_params = [id(p) for p in galore_params]

        # make parameters without "rank" to another group
        regular_params = [p for p in params["params"] if id(p) not in id_galore_params]

        # then call galore_adamw
        param_groups.append({"params": regular_params, "lr": group_lr})
        param_groups.append(
            {
                "params": galore_params,
                "rank": rank,
                "update_proj_gap": update_proj_gap,
                "scale": galore_scale,
                "proj_type": proj_type,
                "lr": group_lr,
            }
        )

        # record lr
        for p in regular_params + galore_params:
            param_lr[id(p)] = group_lr

    # select optimizer
    scheduler = None
    if optimizer_type == "galore_adamw":
        optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs)
    elif optimizer_type == "galore_adafactor":
        beta1 = None if optimizer_kwargs.get("beta1", 0.0) == 0.0 else optimizer_kwargs.pop("beta1")
        optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs)
    elif optimizer_type == "galore_adamw8bit":
        optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs)
    elif optimizer_type == "galore_adamw8bit_per_layer":
        # TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap
        optimizer_dict = {}
        all_params = []
        for params in trainable_params:
            all_params.extend(params["params"])
        for p in all_params:
            if p.requires_grad:
                if id(p) in id_galore_params:
                    optimizer_dict[p] = GaLoreAdamW8bit(
                        [
                            {
                                "params": [p],
                                "rank": rank,
                                "update_proj_gap": update_proj_gap * 2,
                                "scale": galore_scale,
                                "proj_type": proj_type,
                            }
                        ],
                        lr=param_lr[id(p)],
                        weight_decay=weight_decay,
                    )
                else:
                    import bitsandbytes as bnb

                    optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay)

        # get scheduler dict
        # scheduler needs accelerate.prepare?
        scheduler_dict = {}
        for p in all_params:
            if p.requires_grad:
                scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)

        def optimizer_hook(p):
            # runs once per parameter as soon as its gradient is accumulated
            if p.grad is None:
                return
            optimizer_dict[p].step()
            optimizer_dict[p].zero_grad()
            scheduler_dict[p].step()

        # Register the hook onto every parameter
        for p in all_params:
            if p.requires_grad:
                p.register_post_accumulate_grad_hook(optimizer_hook)

        # make dummy scheduler and optimizer: the real work happens in the hooks,
        # so the training loop's step()/zero_grad() calls must be no-ops.
        class DummyScheduler:
            # fix: this class is instantiated with an optimizer argument below,
            # but previously defined no __init__, so the call raised TypeError.
            def __init__(self, optimizer):
                self.optimizer = optimizer

            def step(self):
                pass

        class DummyOptimizer:
            def __init__(self, optimizer_dict):
                self.optimizer_dict = optimizer_dict

            def step(self):
                pass

            def zero_grad(self, set_to_none=False):
                pass

        scheduler = DummyScheduler(optimizer_dict[all_params[0]])
        optimizer = DummyOptimizer(optimizer_dict)
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    return optimizer, scheduler