diff --git a/library/lora_util.py b/library/lora_util.py
index d76cd17b..4676b647 100644
--- a/library/lora_util.py
+++ b/library/lora_util.py
@@ -14,7 +14,7 @@ def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lor
     weight = org_module.weight.data.to(device, dtype=torch.float32)
 
-    with torch.autocast(device.type, dtype=torch.float32):
+    with torch.autocast(device.type):
         # SVD decomposition
         V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
@@ -58,7 +58,7 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
     weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
 
-    with torch.autocast(device.type, dtype=torch.float32):
+    with torch.autocast(device.type):
         # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
         V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
         Vr = V[:, : rank]
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index e6b9b95b..dcd0541e 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -7,16 +7,13 @@
 # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
-import math
 import os
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Type, Union
 
 from diffusers import AutoencoderKL
 from transformers import CLIPTextModel
-import numpy as np
 import torch
 from torch import Tensor
 from tqdm import tqdm
-import re
 
 from library.utils import setup_logging
 from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
@@ -80,7 +77,7 @@ class LoRAModule(torch.nn.Module):
             )
             self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
 
-            with torch.autocast(org_module.weight.device.type), torch.no_grad():
+            with torch.autocast("cuda"), torch.no_grad():
                 self.initialize_weights(org_module)
 
         # same as microsoft's
@@ -99,7 +96,7 @@ class LoRAModule(torch.nn.Module):
                 self._org_lora_up = self.lora_up.weight.data.detach().clone()
                 self._org_lora_down = self.lora_down.weight.data.detach().clone()
             elif self.initialize == "pissa":
-                initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
+                initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
                 # Need to store the original weights so we can get a plain LoRA out
                 self._org_lora_up = self.lora_up.weight.data.detach().clone()
                 self._org_lora_down = self.lora_down.weight.data.detach().clone()
@@ -115,7 +112,7 @@ class LoRAModule(torch.nn.Module):
                 self._org_lora_up = lora_up.weight.data.detach().clone()
                 self._org_lora_down = lora_down.weight.data.detach().clone()
             elif self.initialize == "pissa":
-                initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
+                initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=torch.device("cuda"))
                 # Need to store the original weights so we can get a plain LoRA out
                 self._org_lora_up = lora_up.weight.data.detach().clone()
                 self._org_lora_down = lora_down.weight.data.detach().clone()
@@ -1090,7 +1087,7 @@ class LoRANetwork(torch.nn.Module):
         state_dict = self.state_dict()
 
         if self.initialize in ['pissa']:
-            loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
+            loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
             def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
                 # Calculate ΔW = A'B' - AB
                 delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
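
The last hunk ends inside convert_pissa_to_standard_lora. For orientation only, here is a minimal sketch of how such a conversion is commonly completed; the SVD-truncation step, the even split of singular values, and the return convention are assumptions, not this repository's actual code. Since ΔW = A'B' - AB has rank at most 2*rank, a truncated SVD folds it back into factors that load as a plain LoRA of the original rank:

    import torch
    from torch import Tensor

    def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
        # Calculate ΔW = A'B' - AB (rank at most 2 * rank)
        delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

        # Hypothetical completion: truncated SVD back to `rank` so the result
        # loads as a standard LoRA of the original dimension.
        U, S, Vh = torch.linalg.svd(delta_w.float(), full_matrices=False)
        sqrt_s = torch.sqrt(S[:rank])

        lora_up = U[:, :rank] * sqrt_s            # (out_features, rank)
        lora_down = sqrt_s[:, None] * Vh[:rank]   # (rank, in_features)
        return lora_up, lora_down

A lossless alternative doubles the rank instead of truncating: torch.cat([trained_up, -orig_up], dim=1) together with torch.cat([trained_down, orig_down], dim=0) reproduces ΔW exactly, at the cost of saving a 2*rank adapter.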