diff --git a/networks/lora.py b/networks/lora.py
index cfc517ce..6e860a03 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -514,17 +514,25 @@ class LoRANetwork(torch.nn.Module):
                 all_params.append(param_data)
 
         if self.unet_loras:
-            for lora in self.unet_loras:
-                param_data = {'params': lora.parameters()}
+            if self.stratified_lr:
+                for lora in self.unet_loras:
+                    param_data = {'params': lora.parameters()}
+                    if unet_lr is not None:
+                        param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+                    elif default_lr is not None:
+                        param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
+                    if ('lr' in param_data) and (param_data['lr']==0):
+                        continue
+                    all_params.append(param_data)
+            else:
+                params = []
+                for lora in self.unet_loras:
+                    params.extend(lora.parameters())
+                param_data = {'params': params}
                 if unet_lr is not None:
                     param_data['lr'] = unet_lr
-                elif default_lr is not None:
-                    param_data['lr'] = default_lr
-                if self.stratified_lr and ('lr' in param_data):
-                    param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
-                if (param_data['lr']==0):
-                    continue
                 all_params.append(param_data)
+
         return all_params
 
     def enable_gradient_checkpointing(self):
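
For context, a minimal standalone sketch of the behavior this hunk introduces: with stratified (per-block) learning rates enabled, each U-Net LoRA module gets its own optimizer param group with a scaled lr, and groups whose effective lr is zero are skipped; otherwise all U-Net LoRA parameters go into a single group, as before. The names build_unet_param_groups, FakeLoRAModule and weight_fn below are hypothetical stand-ins for the patched prepare_optimizer_params and get_stratified_lr_weight; this is not the repository's actual LoRANetwork code.

import torch

class FakeLoRAModule(torch.nn.Module):
    # Stand-in for a LoRA module: just something with a name and parameters.
    def __init__(self, name):
        super().__init__()
        self.lora_name = name
        self.lora_down = torch.nn.Linear(8, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 8, bias=False)

def build_unet_param_groups(unet_loras, unet_lr, default_lr, stratified_lr, weight_fn):
    all_params = []
    if stratified_lr:
        # One param group per LoRA module, each with its own scaled learning rate.
        for lora in unet_loras:
            param_data = {'params': lora.parameters()}
            if unet_lr is not None:
                param_data['lr'] = unet_lr * weight_fn(lora)
            elif default_lr is not None:
                param_data['lr'] = default_lr * weight_fn(lora)
            # Groups whose effective lr is zero are dropped, so the optimizer
            # never updates (or tracks state for) those parameters.
            if ('lr' in param_data) and (param_data['lr'] == 0):
                continue
            all_params.append(param_data)
    else:
        # Single param group covering every U-Net LoRA, matching the old behavior.
        params = []
        for lora in unet_loras:
            params.extend(lora.parameters())
        param_data = {'params': params}
        if unet_lr is not None:
            param_data['lr'] = unet_lr
        all_params.append(param_data)
    return all_params

# Example: three blocks weighted 0.0, 0.5, 1.0 -> the zero-weight block is skipped.
loras = [FakeLoRAModule(f'lora_unet_block_{i}') for i in range(3)]
weights = {loras[0].lora_name: 0.0, loras[1].lora_name: 0.5, loras[2].lora_name: 1.0}
groups = build_unet_param_groups(loras, unet_lr=1e-4, default_lr=None,
                                 stratified_lr=True,
                                 weight_fn=lambda l: weights[l.lora_name])
print(len(groups), [g['lr'] for g in groups])  # 2 [5e-05, 0.0001]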