From 19340d82e6fb2a081cadb5fc4c6f38aa627ea81d Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 2 Apr 2023 12:57:55 +0900 Subject: [PATCH] =?UTF-8?q?=E5=B1=A4=E5=88=A5=E5=AD=A6=E7=BF=92=E7=8E=87?= =?UTF-8?q?=E3=82=92=E4=BD=BF=E3=82=8F=E3=81=AA=E3=81=84=E5=A0=B4=E5=90=88?= =?UTF-8?q?=E3=81=ABparams=E3=82=92=E3=81=BE=E3=81=A8=E3=82=81=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index cfc517ce..6e860a03 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module): all_params.append(param_data) if self.unet_loras: - for lora in self.unet_loras: - param_data = {'params': lora.parameters()} + if self.stratified_lr: + for lora in self.unet_loras: + param_data = {'params': lora.parameters()} + if unet_lr is not None: + param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora) + elif default_lr is not None: + param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora) + if ('lr' in param_data) and (param_data['lr']==0): + continue + all_params.append(param_data) + else: + params = [] + for lora in self.unet_loras: + params.extend(lora.parameters()) + param_data = {'params': params} if unet_lr is not None: param_data['lr'] = unet_lr - elif default_lr is not None: - param_data['lr'] = default_lr - if self.stratified_lr and ('lr' in param_data): - param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora) - if (param_data['lr']==0): - continue all_params.append(param_data) + return all_params def enable_gradient_checkpointing(self):